RLHF on Nutanix Cloud Platform



Introduction 

The goal of this article is to show how customers can use a reinforcement learning from human feedback (RLHF) workflow to fine-tune a large language model (LLM) from scratch using open source Python® libraries on the Nutanix Cloud Platform™ (NCP) HCI solution. RLHF is increasingly being used over vanilla supervised fine-tuning or instruction tuning because it generalizes better for dynamic datasets with frequent distribution shifts. Most chat models, such as “meta-llama/Llama-2-70b-chat”, use RLHF for downstream task-specific adaptation, and we believe that a deep understanding of RLHF aids in better decision-making for LLM downstream task adaptation. In general, RLHF with proximal policy optimization (PPO) is the key technique used in ChatGPT, Anthropic’s Claude, and the Llama™ Chat models for downstream adaptation. Essentially, RLHF replaces static fine-tuning data with human feedback, which is often more intuitive to generate and captures real human assessments of helpfulness and safety.

Provisioning compute infrastructure for RLHF can be a complex endeavor, especially in consideration of data privacy, data sovereignty, and data governance. The breakneck speed of AI innovation around model and training dataset availability and compute optimization is triggering rapid enterprise AI adoption. This gold rush demands infrastructure solutions that scale easily, securely, and robustly with teams trained to manage more traditional workloads. In this article, we show how NCP can be leveraged to run an RLHF workflow with a GPT-2-based policy network and a BERT-based reward network on the IMDB dataset. We validate that RLHF tunes the policy network and thereby improves the average positive sentiment score of the policy network's output by a significant margin.

RLHF Workflow

The goal of RLHF is to alter the behavior of a supervised policy, essentially an LLM in this context, based on human feedback. Human annotators assign valuations to LLM predictions according to their preferences, and RLHF uses those preferences to tune the LLM parameters. Reinforcement learning deals with a goal-seeking agent that interacts with an environment to maximize a long-term cumulative reward by choosing appropriate actions. In RLHF, the role of the RL agent is played by a policy optimization algorithm such as PPO, which tunes the behavior of the LLM by updating its parameters, while the role of the environment is played by a reward network built from a compilation of human feedback. Figure 1 shows a schematic representation of an RLHF workflow.

Figure 1: Schematic representation of the RLHF workflow in action

An RLHF workflow involves the following steps.

  1. Collect demonstration data and train a supervised policy
    • Compile a collection of (sample, desired response) pairs
    • Fine-tune a pre-trained model such as GPT-3.5 or Llama-7b.

In most enterprise use cases, a pre-trained or supervised fine-tuned model is selected. 

  2. Collect comparison data and train a reward model
    • Compile a prompt and several model outputs
    • Annotators rank the outputs from best to worst
    • Train a reward model on the ranked outputs, which capture human preferences
  3. Optimize the policy against the reward model using a PPO-based RL algorithm. This is essentially the RLHF loop.
    • A new prompt is sampled from the dataset
    • The PPO model is initialized from the supervised policy
    • The policy generates an output
    • The reward model calculates a reward for the output
    • The reward is used to update the policy using PPO

This article walks through the implementation of these steps from scratch. Before jumping into the implementation, we discuss the Nutanix infrastructure stack for running AI workloads. In an RLHF workflow, PPO plays the central role of the learning algorithm. PPO is an on-policy learning algorithm, which means it explores by sampling actions according to the latest version of the policy. In this setting, the action space consists of the tokens the LLM can generate, the state is the prompt together with the tokens generated so far, and the policy itself is the LLM with its billions of parameters.
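
For reference, PPO updates the policy by maximizing a clipped surrogate objective. The standard formulation (general PPO, not specific to the trl implementation used later) is

\[
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)},
\]

where \(\hat{A}_t\) is the advantage estimate and \(\epsilon\) is the clipping parameter. The clipping keeps each update close to the policy that generated the samples, which is what makes PPO stable enough for the RLHF loop described below.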

Nutanix Cloud Platform

At Nutanix, we are dedicated to enabling customers to build and deploy intelligent applications anywhere – edge, core data centers, service provider infrastructure, and public clouds. Figure 2 shows how AI/ML is integrated into the core Nutanix® infrastructure layer. RLHF is a key workflow in the development of AI Applications at the top.

Figure 2: AI stack running on the cloud-native infrastructure stack of NCP. The stack provides holistic integration across layers: the chip layer, the virtual machine layer, the supporting library/tooling layer, and the AI layer, which includes foundation models (different variants of transformers) and task-specific AI apps.

As shown in Figure 2, the App layer runs on top of the infrastructure layer. The infrastructure layer can be deployed in two steps, starting with Prism Element™ login followed by VM resource configuration. Figure 3 shows the UI for Prism Element.

Figure 3: The UI showing the setup for a Prism Element on which the transformer model for this article was trained. It shows the AHV® hypervisor summary, storage summary, VM summary, hardware summary, monitoring for cluster-wide controller IOPS, monitoring for cluster-wide controller I/O bandwidth, monitoring for cluster-wide controller latency, cluster CPU usage, cluster memory usage, granular health indicators, and data resiliency status.

After logging into Prism Element, we create a virtual machine (VM) hosted on our Nutanix AHV® cluster. As shown in Figure 4, the VM has the following resource configuration: Ubuntu® 22.04 operating system, 16 single-core vCPUs, 64 GB of RAM, and an NVIDIA® A100 Tensor Core passthrough GPU with 40 GB of memory. The GPU runs the NVIDIA RTX 15.0 driver for Ubuntu (NVIDIA-Linux-x86_64-525.60.13-grid.run). Large transformer-based deep learning models require GPUs or other compute accelerators with high memory bandwidth and large register and L1 memory capacity.

Figure 4: The VM resource configuration UI pane on Nutanix Prism Element. As shown, it helps a user configure the number of vCPU(s), the number of cores per vCPU, memory size (GiB), and the GPU choice. We used an NVIDIA A100 40G for this article.

The NVIDIA A100 Tensor Core GPU is designed to power the world’s highest-performing elastic datacenters for AI, data analytics, and HPC. Powered by the NVIDIA Ampere™ architecture, A100 is the engine of the NVIDIA data center platform. A100 provides up to 20X higher performance over the prior generation and can be partitioned into seven GPU instances to dynamically adjust to shifting demands.

To examine the A100 GPU in detail, we run the `nvidia-smi` command, a command-line utility built on top of the NVIDIA Management Library (NVML) for managing and monitoring NVIDIA GPU devices. The output of the `nvidia-smi` command is shown in Figure 5. It reports a driver version of 515.86.01 and a CUDA version of 11.7, along with several other critical features of the A100 GPU we used. These features are described in Table 1.

Figure 5: Output of `nvidia-smi` for the underlying A100 GPU
Feature | Value | Description
GPU | 0 | GPU index
Name | NVIDIA A100 | GPU name
Temp | 34 C | Core GPU temperature
Perf | P0 | GPU performance state
Persistence-M | On | Persistence mode
Pwr: Usage/Cap | 36W / 250W | GPU power usage and its cap
Bus Id | 00000000:00:06.0 | domain:bus:device.function
Disp. A | Off | Display active
Memory-Usage | 25939MiB / 40960MiB | Memory allocated out of total memory
Volatile Uncorr. ECC | 0 | Count of uncorrectable ECC memory errors
GPU-Util | 0% | GPU utilization
Compute M. | Default | Compute mode
MIG M. | Disabled | Multi-Instance GPU (MIG) mode
Table 1: Description of the key features of the underlying A100 GPU.
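
Because `nvidia-smi` is built on NVML, the same telemetry can also be queried programmatically from Python, for example through the `pynvml` bindings. A minimal sketch, assuming `pynvml` is installed (e.g. via `pip install pynvml`):

import pynvml

# Initialize NVML and grab a handle to GPU 0 (the passthrough A100 in our VM)
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Query name, memory, utilization, and temperature, mirroring the nvidia-smi columns
name = pynvml.nvmlDeviceGetName(handle)
name = name.decode() if isinstance(name, bytes) else name  # older pynvml returns bytes
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)

print(f"{name}: {mem.used / 2**20:.0f} MiB / {mem.total / 2**20:.0f} MiB used, "
      f"{util.gpu}% utilization, {temp} C")

pynvml.nvmlShutdown()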

Implementation of RLHF Workflow

This section walks through the implementation of the RLHF workflow. We test whether the RLHF loop improves the performance of the GPT-2 model by a significant margin.

  1. Import Necessary Python Libraries
import torch
from tqdm import tqdm
import pandas as pd

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
  2. Set proximal policy optimization (PPO) configurations
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
)

sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}

We use the “lvwerra/gpt2-imdb” model, which is a GPT-2 model fine-tuned on the IMDB dataset. The learning rate used is 1.41e-5. Essentially, “lvwerra/gpt2-imdb” is the policy network.

  3. Build data loader from IMDB dataset
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build a dataset for training.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The data loader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the IMDB dataset and keep only sufficiently long reviews
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    # Sample a random prompt length for each review
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # Truncate each review to a short prompt and keep its decoded text as the query
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")

    return ds
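
The PPOTrainer in step 5 expects a dataset and a data collator. A minimal sketch of how both can be defined, following the trl sentiment-tuning example (the collator simply turns a list of samples into a dictionary of lists):

dataset = build_dataset(config)


def collator(data):
    # Convert a list of dicts into a dict of lists, as expected by PPOTrainer
    return {key: [d[key] for d in data] for key in data[0]}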
  4. Load Model, Reference Model, and Tokenizer Locally
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

We use the “AutoModelForCausalLMWithValueHead” module from “trl” to load the model with a value head. We load the model twice: the first copy is optimized with PPO, and the second serves as a frozen reference for calculating the KL divergence from the starting point. This KL term serves as a secondary reward signal during PPO training to make sure the optimized model does not deviate too much from the original language model.
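
Conceptually, the quantity that PPO maximizes combines the reward model score with this KL penalty. In a simplified, sequence-level form (the coefficient \(\beta\) and the notation here are ours, not taken from the trl implementation):

\[
R(x, y) = r_\phi(x, y) - \beta \, D_{\mathrm{KL}}\big(\pi_\theta(\cdot \mid x) \,\|\, \pi_{\mathrm{ref}}(\cdot \mid x)\big),
\]

where \(r_\phi\) is the reward model, \(\pi_\theta\) is the policy being tuned, and \(\pi_{\mathrm{ref}}\) is the frozen reference copy loaded above.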

  5. Initialize PPOTrainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
  6. Initialize the Reward Model
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

For the reward model, we use the “lvwerra/distilbert-imdb” model, a fine-tuned version of distilbert-base-uncased. This model assigns a sentiment score to the output of the policy network; essentially, it is a proxy for a human evaluator.
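
As a quick sanity check, the reward pipeline can be queried directly with a piece of text; a minimal sketch (the review text below is illustrative):

text = "This movie was really good!!"
print(sentiment_pipe(text, **sent_kwargs))
# With function_to_apply="none", the pipeline returns raw scores for both labels,
# e.g. a list like [[{'label': 'NEGATIVE', 'score': ...}, {'label': 'POSITIVE', 'score': ...}]].
# The POSITIVE score is what we use as the reward below.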

  7. RLHF training loop

At this point, we are ready to run the entire RLHF workflow. We use the GPT-2-based policy network to generate a response for each query, the DistilBERT-based reward model to score the sentiment of each (query, response) pair, and PPO to optimize the policy with the resulting (query, response, reward) triplets. With RLHF, we want to improve the reward score, meaning we want to generate responses with more positive sentiment. The following code block covers the entire training loop:
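
A minimal sketch of the loop, following the trl library's GPT-2 sentiment-tuning example (the response-length range and generation settings are illustrative assumptions):

output_length_sampler = LengthSampler(4, 16)  # assumed range of generated response lengths

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Step 1: the policy generates a response for each query
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    # Step 2: the reward model scores each (query, response) pair
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]  # POSITIVE score

    # Step 3: PPO updates the policy with the (query, response, reward) triplets
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)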

It took nearly 3 hours to run the training on an A100 40GB GPU on the Nutanix infrastructure described previously. During training, the mean reward keeps increasing and reaches a plateau after roughly 25 training steps, as shown in Figure 6.

  8. RLHF Evaluation

The RLHF loop is evaluated on 15 queries, as shown in the following screenshot. As we can see, the PPO-based RLHF loop increased the reward in 14 out of 15 cases. The negative outcome on query 7 can be attributed to its sparse, one-word query. In aggregate, PPO-based RLHF training increases the average reward by 30x and the median reward by 7x. This is a significant improvement in the reward score, validating the effectiveness of RLHF in LLM adaptation.

Figure 6: PPO-based training in action: the mean reward from the policy network increases with training steps.
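
A minimal sketch of how this before/after reward comparison can be assembled, reusing the reference model, the reward pipeline, and the generation settings from the training loop (the batch size and decoding pattern are illustrative assumptions):

# Sample a small batch of queries and compare the rewards obtained by the
# reference (pre-RLHF) policy and the tuned (post-RLHF) policy.
bs = 15
eval_batch = dataset[:bs]
queries = eval_batch["query"]

responses_before, responses_after = [], []
for input_ids in eval_batch["input_ids"]:
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    query = input_ids.unsqueeze(0).to(device)
    out_ref = ref_model.generate(query, **generation_kwargs).squeeze()[-gen_len:]
    out_tuned = model.generate(query, **generation_kwargs).squeeze()[-gen_len:]
    responses_before.append(tokenizer.decode(out_ref))
    responses_after.append(tokenizer.decode(out_tuned))

# Score both sets of responses with the same reward pipeline
rewards_before = [out[1]["score"] for out in sentiment_pipe(
    [q + r for q, r in zip(queries, responses_before)], **sent_kwargs)]
rewards_after = [out[1]["score"] for out in sentiment_pipe(
    [q + r for q, r in zip(queries, responses_after)], **sent_kwargs)]

df = pd.DataFrame({"query": queries,
                   "reward (before)": rewards_before,
                   "reward (after)": rewards_after})
print(df)
print("mean reward before:", df["reward (before)"].mean())
print("mean reward after:", df["reward (after)"].mean())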

 

Key Takeaways

  1. We trained an RLHF loop end-to-end from scratch with the PPO algorithm on the IMDB movie review dataset.
  2. The tuned policy shows a significant improvement in reward score, meaning RLHF succeeded in adapting the LLM's behavior.
  3. Finally, it is worth mentioning that in this article we replaced direct human feedback with a BERT-based (DistilBERT) reward network.

© 2023 Nutanix, Inc. All rights reserved. Nutanix, the Nutanix logo and all Nutanix product, feature and service names mentioned herein are registered trademarks or trademarks of Nutanix, Inc. in the United States and other countries. Other brand names mentioned herein are for identification purposes only and may be the trademarks of their respective holder(s). This post may contain links to external websites that are not part of Nutanix.com. Nutanix does not control these sites and disclaims all responsibility for the content or accuracy of any external site. Our decision to link to an external site should not be considered an endorsement of any content on such a site. Certain information contained in this post may relate to or be based on studies, publications, surveys and other data obtained from third-party sources and our own internal estimates and research. While we believe these third-party studies, publications, surveys and other data are reliable as of the date of this post, they have not been independently verified, and we make no representation as to the adequacy, fairness, accuracy, or completeness of any information obtained from third-party sources.

This post may contain express and implied forward-looking statements, which are not historical facts and are instead based on our current expectations, estimates and beliefs. The accuracy of such statements involves risks and uncertainties and depends upon future events, including those that may be beyond our control, and actual results may differ materially and adversely from those anticipated or implied by such statements. Any forward-looking statements included herein speak only as of the date hereof and, except as required by law, we assume no obligation to update or otherwise revise any of such forward-looking statements to reflect subsequent events or circumstances
