0. Abstract
We consider an alternative approach: converting feedback into instructions by relabeling the original ones and training the model for better alignment in a supervised manner.
1. Introduction
There are roughly two existing directions for human alignment:
- Proximal Policy Optimization (PPO): rather complex, sensitive to hyperparameters, and requires additional training of a reward model and a value network
- imitation learning: less data-efficient, as it only makes use of the successful instruction-output pairs and completely discards the ones that do not align
Hindsight Instruction Relabeling (HIR) adopts the central idea of relabeling instructions in a hindsight fashion. HIR alternates between two phases (a minimal sketch of the loop follows this list):
- an online sampling phase to generate a dataset of instruction-output pairs
- an offline learning phase that relabels the instructions of each pair and performs standard supervised learning
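Below is a minimal Python sketch of this alternating loop (not the authors' code): the callables `sample_output`, `relabel`, and `supervised_update` are hypothetical stand-ins for the temperature-1 model rollout, the relabeling function, and one round of standard supervised fine-tuning.

```python
from typing import Callable, List, Tuple

# A minimal sketch of HIR's alternating loop, not the authors' implementation.
def hir_loop(
    prompts_and_queries: List[Tuple[str, str]],
    sample_output: Callable[[str, str], str],
    relabel: Callable[[str, str, str], str],
    supervised_update: Callable[[List[Tuple[str, str, str]]], None],
    num_iterations: int = 10,
) -> None:
    for _ in range(num_iterations):
        # Online sampling phase: build the replay dataset D_online of
        # (instruction, query, output) tuples from the current model.
        replay = [(p, q, sample_output(p, q)) for p, q in prompts_and_queries]
        # Offline learning phase: relabel each instruction in hindsight so it
        # aligns with what the model actually produced, then train on it.
        relabeled = [(relabel(p, q, o), q, o) for p, q, o in replay]
        supervised_update(relabeled)
```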
2. Hindsight Instruction Relabeling
We can formulate the language model alignment as a goal-conditioned RL problem.
(1) Instruction Following as Goal-conditioned RL
A language model $\mathcal{M}$ can take an instructional prompt $\textbf{p}$ and an initial query token sequence $\textbf{q} = \{\text{q}_{0},...,\text{q}_{i}\}$ as input, and autoregressively predict the next token $\text{e}_{i+1} = \mathcal{M}\left( \textbf{p}, \textbf{q}, \{\text{e}_{0},...,\text{e}_{i}\} \right)$.
We can view standard prompt-conditioned language tasks (e.g. multi-step reasoning) as a goal-reaching problem.
- Goal Space $\mathcal{G}$: space of instructional prompt $\textbf{p}$
- State space $\mathcal{S}$: space of input token sequence $\textbf{q} \cup \{\text{e}_{i} \}$
- Action space $\mathcal{A}$: space of output token $\text{e}_{i+1}$
- Transition probability $\mathcal{P}$: $ \mathcal{M}\left( \text{e}_{i+1} | \textbf{p}, \textbf{q}, \{\text{e}_{0},...,\text{e}_{i}\}\right) $
- Reward $\mathcal{R}$: alignment score of $\{\text{e}_{0},...,\text{e}_{i+1}\}$ with instruction $\textbf{p}$ and query $\textbf{q}$, from human or scripted feedback; this reward is not directly used as a training signal in HIR.
Here all of $\mathcal{G}, \mathcal{S}, \mathcal{A}$ are spaces of token embeddings, but $\mathcal{G}$ corresponds to instructional prompts, while $\mathcal{S}$ and $\mathcal{A}$ correspond to model inputs and outputs. In this way, we can also view the language model as a goal-conditioned policy:
$$\pi \triangleq \mathcal{M}\left( \text{e}_{i+1} \mid \textbf{p}, \textbf{q}, \{\text{e}_{0},...,\text{e}_{i}\} \right)$$
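A toy sketch of this view, assuming a hypothetical `model` that maps a `[batch, seq_len]` tensor of token ids to `[batch, seq_len, vocab]` logits: the prompt plays the role of the goal, the query plus the generated tokens form the state, and each sampled token is an action.

```python
import torch

# Toy illustration of the goal-conditioned-policy view (assumption: `model`
# returns next-token logits of shape [batch, seq_len, vocab]).
def rollout(model, prompt_ids, query_ids, max_len=32, temperature=1.0):
    state = list(query_ids)                                    # state: q plus e_0..e_i
    generated = []
    for _ in range(max_len):
        context = torch.tensor([list(prompt_ids) + state])     # condition on the goal p
        logits = model(context)[0, -1]                          # next-token logits
        probs = torch.softmax(logits / temperature, dim=-1)     # policy pi(e_{i+1} | p, q, e_0..e_i)
        action = torch.multinomial(probs, 1).item()             # sample the action e_{i+1}
        generated.append(action)
        state.append(action)                                    # transition: append the token
    return generated
```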
(2) Algorithm Overview
i. Online Sampling
Given an instruction $\textbf{p}$ and query $\textbf{q}$, we sample from the model with temperature $\tau = 1$ to get the output sequence $\textbf{o} = \{\text{e}_{0},\text{e}_{1},...,\text{e}_{L}\}$, which gives us the online replay dataset $\mathcal{D}_{\text{online}}$.
$$\mathcal{D}_{\text{online}} = \bigcup_{ i=1}^{N}\left\{ \textbf{p}_{i},\textbf{q}_{i},\textbf{o}_{i} \right\}$$
ii. Offline Relabeling
For every instruction-output pair $(\textbf{p},\textbf{q},\textbf{o})$, which is not necessarily aligned, we relabel the pair with a new instruction that aligns with the actual outcome of the model, giving $(\textbf{p}^{*},\textbf{q},\textbf{o})$.
The new instruction $\textbf{p}^{*}$ is generated from the feedback function $\mathcal{R}(\textbf{p},\textbf{q},\textbf{o})$ and the instruction generation function $\phi(\textbf{p},\textbf{q},\textbf{o}, \textbf{r})$, which can be either learned or scripted. For simplicity, $\phi$ is also scripted, based on the correctness of the reasoning outcome.
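As an illustration (the exact prompt strings here are hypothetical, not quoted from the paper), a scripted $\phi$ for a reasoning task can pick the new instruction purely from whether the sampled output is correct:

```python
# A hedged sketch of a scripted relabeling function phi for a reasoning task.
def scripted_phi(prompt: str, query: str, output: str, is_correct: bool) -> str:
    if is_correct:
        # The output already aligns with the original instruction: keep it.
        return prompt
    # Otherwise relabel in hindsight with an instruction the output does
    # satisfy, so the failed sample still becomes a valid training pair.
    return "Give a wrong answer to the question."
```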
(3) Instruction Relabeling
HIR also conducts instruction relabeling at intermediate time steps on the generated sub-outputs.
i. Sub-output Relabeling
It is important to also sample partial outputs and relabel their instructions; in this way, instruction relabeling provides denser feedback.
Suppose we relabel at the $i$-th time step. The input to the model is $\textbf{q} \cup \{\text{e}_{0},...,\text{e}_{i-1}\}$. We can edit the instruction as a future goal based on the future alignment score:
$$\textbf{p}^{*} = \phi\left( \textbf{p}, \textbf{q}, \{\text{e}_{0},...,\text{e}_{L}\}, \mathcal{R}\left( \textbf{p}, \textbf{q}, \{\text{e}_{0},...,\text{e}_{L}\} \right) \right)$$
where $\phi$ and $\mathcal{R}$ are the instruction generation function and feedback function.
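A minimal sketch of this step, with hypothetical helpers `feedback` and `phi` standing in for $\mathcal{R}$ and the instruction generation function: the relabeled instruction is computed once from the full output, and each sampled intermediate step yields a training example whose input is the query plus the first $i$ generated tokens.

```python
import random

# Sub-output relabeling sketch (assumed helpers: `feedback` plays the role of R
# and `phi` the instruction generation function).
def relabel_sub_outputs(prompt, query_tokens, output_tokens, feedback, phi, num_samples=4):
    r = feedback(prompt, query_tokens, output_tokens)           # future alignment score
    new_prompt = phi(prompt, query_tokens, output_tokens, r)    # hindsight goal p*
    examples = []
    for _ in range(num_samples):
        i = random.randrange(len(output_tokens))                # intermediate time step i
        partial_input = list(query_tokens) + list(output_tokens[:i])  # q plus e_0..e_{i-1}
        target = list(output_tokens[i:])                         # remaining tokens to predict
        examples.append((new_prompt, partial_input, target))
    return examples
```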
ii. Contrastive Instruction Following
Suppose $\textbf{o}_{i} = \mathcal{M}(\textbf{q}_{i}, \textbf{p}_{i})$. Define the log probability of $\textbf{o}_{i}$ conditioned on $\textbf{q}_{k}, \textbf{p}_{k}$ as:
$$\mathcal{P}_{ik} = \log P_{\mathcal{M}}(\textbf{o}_{i} \mid \textbf{q}_{k}, \textbf{p}_{k})$$
We define the following contrastive loss:
$$\mathcal{L}_{\text{contrastive}} = -\sum_{i=1}^{n}\log\frac{\exp(\mathcal{P}_{ii})}{\sum_{k=1}^{n}\exp(\mathcal{P}_{ik})}$$
This helps prevent the model from learning to map different instructions to the same output.
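A minimal PyTorch sketch of this loss, assuming `log_p` is the $n \times n$ matrix with entries $\mathcal{P}_{ik}$: it is simply the negative row-wise log-softmax evaluated on the diagonal.

```python
import torch

# Contrastive loss: each output o_i should be most likely under its own
# instruction p_i rather than any other p_k.
def contrastive_loss(log_p: torch.Tensor) -> torch.Tensor:
    row_log_softmax = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)
    return -row_log_softmax.diagonal().sum()

# Example with three instruction-output pairs:
# loss = contrastive_loss(torch.randn(3, 3))
```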
iii. Entropy Regularization
As a common practice in RL, we apply entropy regularization to the output distribution given a particular instruction. Adding this negative-entropy term to the loss encourages higher entropy, so the sampling phase does not converge too early and exploration is preserved.
$$\mathcal{L}_{\text{entropy}} = \sum_{k=1}^{n}\mathcal{P}_{k}\log\mathcal{P}_{k}$$
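A minimal PyTorch sketch under the assumption that the term is computed from the model's next-token logits (the paper may define the distribution $\mathcal{P}_{k}$ slightly differently):

```python
import torch

# Negative-entropy regularizer: sum_k P_k log P_k, averaged over positions.
# Minimizing it penalizes prematurely collapsed output distributions.
def entropy_regularizer(logits: torch.Tensor) -> torch.Tensor:
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    return (probs * log_probs).sum(dim=-1).mean()
```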
3. Experiments