def compute_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    *,
    normalization: str = "maxrl",
):
    """Compute group-normalized outcome advantages.

    Each response's score is the sum of its token-level rewards. Responses
    that share the same ``index`` entry form a group; every score is centred
    on its group mean and then divided by a group statistic:

    * ``"grpo"``  -- ``(score - mean) / (std + epsilon)``
    * ``"maxrl"`` -- ``(score - mean) / (mean + epsilon)``; the group mean is
      the pass rate when rewards are binary 0/1.

    The normalized score is broadcast over the token dimension and multiplied
    by ``response_mask``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last
            dimension to get one score per response. Assumed to be a
            floating-point tensor of shape ``(batch, response_length)``
            -- TODO confirm against callers.
        response_mask: Mask multiplied into the result; must broadcast
            against ``(batch, 1)``.
        index: Group identifier for each batch row. Entries must be hashable
            and there must be at least ``batch`` of them.
        epsilon: Added to the denominator to avoid division by zero.
        normalization: ``"grpo"`` or ``"maxrl"``. Keyword-only.

    Returns:
        A 2-tuple holding the *same* tensor twice (presumably consumed as
        ``(advantages, returns)`` -- verify against the caller).

    Raises:
        ValueError: If ``normalization`` is not a supported mode.

    Note:
        With ``"maxrl"`` the denominator is the group mean itself, so the
        result is only well behaved for non-negative rewards.
    """
    if normalization not in ("grpo", "maxrl"):
        raise ValueError(
            f"normalization must be 'grpo' or 'maxrl', got {normalization!r}"
        )

    with torch.no_grad():
        # 0. One scalar outcome score per response.
        scores = token_level_rewards.sum(dim=-1)
        bsz = scores.shape[0]

        # 1. Group row positions by index.
        rows_by_group = {}
        for i in range(bsz):
            rows_by_group.setdefault(index[i], []).append(i)

        # 2./3. Per-group statistics and normalization. Results go into a
        # fresh tensor so normalized values never feed back into the stats.
        advantages = torch.zeros_like(scores)
        for rows in rows_by_group.values():
            group_scores = scores[rows]
            group_mean = group_scores.mean()
            if normalization == "grpo":
                # GRPO: normalize by standard deviation. The unbiased std of
                # a single sample is NaN, so treat a singleton group's
                # spread as zero (its centred score is zero anyway).
                if len(rows) > 1:
                    denominator = group_scores.std()
                else:
                    denominator = torch.zeros_like(group_mean)
            else:
                # MaxRL: normalize by pass rate (the group mean).
                denominator = group_mean
            advantages[rows] = (group_scores - group_mean) / (denominator + epsilon)

        advantages = advantages.unsqueeze(-1) * response_mask

    return advantages, advantages