prompting.validators.reward.relevance#
Module Contents#
Classes#
- RelevanceRewardEvent
- RelevanceRewardModel
- BertRelevanceRewardModel
- MpnetRelevenceModel
Functions#
- mean_pooling(model_output, attention_mask) – Applies mean pooling to the token embeddings generated by the model.
- prompting.validators.reward.relevance.mean_pooling(model_output, attention_mask)#
Applies mean pooling to the token embeddings generated by the model.
- Parameters:
model_output (torch.Tensor) – Embedding model output, where the first element contains token embeddings.
attention_mask (torch.Tensor) – Attention mask to indicate valid tokens.
- Returns:
Mean-pooled representation of the token embeddings.
- Return type:
torch.Tensor
Notes
The function calculates the mean-pooled representation using the attention mask for valid tokens.
- input_mask_expanded is created by expanding the attention mask to match the size of the token embeddings.
- The result is obtained by summing the element-wise multiplication of embeddings and input_mask_expanded, and dividing by the sum of input_mask_expanded after clamping its values to a minimum of 1e-9.
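A minimal sketch of the pooling described in the notes above; it follows the common sentence-transformers recipe and may differ in detail from the module's actual implementation:

```python
import torch

def mean_pooling(model_output, attention_mask):
    # First element of the model output holds the per-token embeddings.
    token_embeddings = model_output[0]
    # Expand the attention mask to the embedding dimension so it can be
    # multiplied element-wise with the token embeddings.
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    # Sum the masked embeddings and divide by the number of valid tokens,
    # clamping the denominator to 1e-9 to avoid division by zero.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
```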
- class prompting.validators.reward.relevance.RelevanceRewardEvent#
- class prompting.validators.reward.relevance.RelevanceRewardModel(device)#
Bases:
prompting.validators.reward.reward.BaseRewardModel
- Parameters:
device (str) –
- get_rewards(prompt, completions, name)#
- Parameters:
- Return type:
List[RelevanceRewardEvent]
- normalize_rewards(rewards)#
This method normalizes the given rewards by updating the moving mean and variance statistics. The rewards are first standardized, and then scaled to the 0-1 range using a cumulative distribution function (CDF) to ensure they’re in a comparable range across different environments.
- Parameters:
rewards (torch.FloatTensor) – The reward values to be normalized.
- Returns:
The normalized reward values.
- Return type:
torch.FloatTensor
Notes
- This function uses Welford’s online algorithm to update the mean and variance.
- It standardizes the reward values using the updated mean and variance.
- It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF.
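For illustration, a hedged sketch of the normalization scheme described above; the class name and running-statistics attributes (count, mean, var) are assumptions made for the example, not the module's actual state variables:

```python
import math
import torch

class RunningRewardNormalizer:
    def __init__(self):
        # Running statistics over all rewards seen so far (assumed names).
        self.count = 0
        self.mean = 0.0
        self.var = 0.0

    def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor:
        # Welford-style online update of the running mean and variance
        # using the statistics of the incoming batch.
        new_count = rewards.numel()
        new_mean = rewards.mean().item()
        new_var = rewards.var(unbiased=False).item()
        total = self.count + new_count
        delta = new_mean - self.mean
        self.mean += delta * new_count / total
        self.var = (self.var * self.count + new_var * new_count
                    + delta ** 2 * self.count * new_count / total) / total
        self.count = total
        # Standardize with the updated statistics, then squash to the 0-1
        # range via the normal CDF, expressed through the error function.
        std = math.sqrt(max(self.var, 1e-9))
        z = (rewards - self.mean) / std
        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
```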
- reward(prompt, completion, name)#
- Parameters:
- Return type:
- class prompting.validators.reward.relevance.BertRelevanceRewardModel(device)#
Bases:
prompting.validators.reward.reward.BaseRewardModel
- Parameters:
device (str) –
- relevance_model_path = 'bert-base-uncased'#
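The BERT checkpoint above is typically loaded with Hugging Face transformers and combined with mean_pooling (documented earlier) to embed prompts and completions. A hedged usage sketch, not the class's exact code:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Tokenize a completion and mean-pool the token embeddings into a single vector.
encoded = tokenizer(["example completion"], padding=True, truncation=True,
                    return_tensors="pt")
with torch.no_grad():
    output = model(**encoded)
embedding = mean_pooling(output, encoded["attention_mask"])
```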
- class prompting.validators.reward.relevance.MpnetRelevenceModel(device)#
Bases:
prompting.validators.reward.reward.BaseRewardModel
- Parameters:
device (str) –
- diversity_model_path = 'sentence-transformers/all-mpnet-base-v2'#
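The mpnet checkpoint above is commonly used to produce sentence embeddings whose pairwise cosine similarity serves as a relevance signal. A minimal, hedged sketch; the variable names and placeholder tensors are illustrative, not the module's actual code:

```python
import torch
import torch.nn.functional as F

# Assume prompt_embedding and completion_embeddings were produced by
# mean-pooling the output of 'sentence-transformers/all-mpnet-base-v2'
# (see the mean_pooling example earlier); placeholders are used here.
prompt_embedding = torch.randn(1, 768)
completion_embeddings = torch.randn(4, 768)

# Cosine similarity between the prompt and each completion: higher scores
# indicate completions that are semantically closer to the prompt.
scores = F.cosine_similarity(prompt_embedding, completion_embeddings, dim=-1)
print(scores)
```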