prompting.validators.reward#

Package Contents#

Classes#

Blacklist

NSFWRewardModel

DirectPreferenceRewardModel

OpenAssistantRewardModel

ReciprocateRewardModel

RelevanceRewardModel

BaseRewardModel

DahoasRewardModel

DiversityRewardModel

PromptRewardModel

RewardModelType – Create a collection of name/value pairs.

DefaultRewardFrameworkConfig – Reward framework default configuration.

class prompting.validators.reward.Blacklist(boundary=6, n_min=5, n_max=14, word_limit=2000, A=1.3, preprocess='[^(\\w|\\s)]', partial_ratio_boundary=95, half_life=20000, support=0.01, error=0.001, memory_lim=1000000, frequency_multiplier=100)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:
  • boundary (float) –

  • n_min (int) –

  • n_max (int) –

  • word_limit (int) –

  • A (float) –

  • preprocess (str) –

  • partial_ratio_boundary (float) –

  • half_life (int) –

  • support (float) –

  • error (float) –

  • memory_lim (int) –

  • frequency_multiplier (float) –

property name: str#
Return type:

str

add(texts)#

Extract n-grams from a list of texts and add them to the counter.

Parameters:

texts (list) – batch of completion texts

extract_ngrams(text)#

Extract n-grams from a text string.

Parameters:

text (str) – completion text

Returns:

List of n-gram tuples

Return type:

list

_add_ngrams(ngrams)#

Adds n-grams to the counter, periodically removing old n-grams. The counting and pruning method is based on the lossy counting algorithm. Reference: https://files.ifi.uzh.ch/dbtg/sdbs13/T01.3.pdf

Parameters:

ngrams (List[tuple]) – List of n-gram tuples

prune()#

Prune the counter, removing entries whose count is smaller than the bucket index.
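For intuition, here is a minimal, self-contained sketch of the lossy counting scheme referenced above (an illustration of the algorithm, not the class's actual implementation; bucket handling and the error guarantee are simplified):

from collections import Counter

def lossy_count(stream, error=0.001):
    """Approximate frequency counting in the spirit of lossy counting.

    Items whose relative frequency exceeds `error` tend to survive pruning,
    while rare items are periodically discarded to bound memory use.
    """
    width = int(1 / error)      # bucket width
    counts = Counter()
    bucket = 1                  # current bucket index
    for i, item in enumerate(stream, start=1):
        counts[item] += 1
        if i % width == 0:      # end of a bucket: prune low-count items
            for key in [k for k, c in counts.items() if c <= bucket]:
                del counts[key]
            bucket += 1
    return counts

# Example: count word bigrams extracted from two completion texts.
texts = ["the quick brown fox", "the quick red fox"]
bigrams = [tuple(words[i:i + 2]) for t in texts
           for words in [t.split()] for i in range(len(words) - 1)]
print(lossy_count(bigrams, error=0.25))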

reset()#

Reset counters to initial values.

calculate_significance()#

Calculate significance of all n-grams in counter. By construction, n-grams with count 1 will have significance 0.

Returns:

Dictionary of n-gram tuples and their significance scores

Return type:

dict

get_significance()#

Get significance scores, only recalculating if the counter has been updated.

Returns:

Dictionary of n-gram tuples and their significance scores

Return type:

dict

most_common(n=10)#

Get the most common n-grams in the queue.

Parameters:

n (int) – Number of most common n-grams to return. Defaults to 10.

Returns:

Sorted dictionary of n-gram tuples and their counts

Return type:

dict

most_significant(n=10, force_update=True)#

Get the most significant n-grams in the queue, based on their significance scores.

Parameters:
  • n (int, optional) – Number of most significant n-grams to return. Defaults to 10.

  • force_update (bool, optional) – Force recalculation of the significance scores. Defaults to True.

Returns:

Sorted dictionary of n-gram tuples and their significance scores

Return type:

dict

set_counter_to_half()#

Halve all counter values to produce a rolling-window effect.

reward(prompt, completion, name)#

Reward function for the blacklist reward model. Returns 1 if the completion contains an n-gram whose significance exceeds the boundary, and 0 otherwise.

Parameters:
  • prompt (str) – Prompt text

  • completion (str) – Completion text

  • name (str) – Name of the validation step

Returns:

Reward value {0,1}

Return type:

float

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[BlacklistRewardEvent]
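Taken together, the methods above suggest a usage pattern along these lines (a hedged sketch based only on the documented signatures; argument values and the validation-step name are placeholders):

from prompting.validators.reward import Blacklist

bl = Blacklist(n_min=5, n_max=14, support=0.01, error=0.001)

# Feed a batch of completion texts so their n-grams enter the counter.
bl.add(["first completion text ...", "second completion text ..."])

# Inspect the counter state.
print(bl.most_common(n=5))        # most frequent n-grams
print(bl.most_significant(n=5))   # highest-significance n-grams

# Score a single completion: 1 if it contains an n-gram whose significance
# exceeds the boundary, 0 otherwise.
score = bl.reward("prompt text", "completion text ...", name="augment")

# Score a batch of completions; returns a List[BlacklistRewardEvent].
events = bl.get_rewards("prompt text", ["completion a", "completion b"], name="augment")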

normalize_rewards(rewards)#

This method normalizes the given rewards by updating the moving mean and variance statistics. The rewards are first standardized, and then scaled to the 0-1 range using a cumulative distribution function (CDF) to ensure they are in a comparable range across different environments.

Note:

  • This function uses Welford's online algorithm to update the mean and variance.

  • It standardizes the reward values using the updated mean and variance.

  • It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF.

Parameters:

rewards (torch.FloatTensor) – The reward values to be normalized.

Returns:

The normalized reward values.

Return type:

torch.FloatTensor

class prompting.validators.reward.NSFWRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

nsfw_filter_model_path = 'facebook/roberta-hate-speech-dynabench-r4-target'#
reward(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

NSFWRewardEvent

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[NSFWRewardEvent]

normalize_rewards(rewards)#
Parameters:

rewards (torch.FloatTensor) –

Return type:

torch.FloatTensor
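For reference, the documented nsfw_filter_model_path points at a standard Hugging Face sequence-classification checkpoint. Below is a minimal sketch of querying it directly; this is an illustration only, and how NSFWRewardModel maps the classifier output onto a reward is not shown above:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
tokenizer = AutoTokenizer.from_pretrained(model_name)
classifier = AutoModelForSequenceClassification.from_pretrained(model_name)

def hate_probability(text: str) -> float:
    """Probability that the classifier flags the text (label index 1 is
    'hate' per the model card)."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        probs = torch.softmax(classifier(**inputs).logits, dim=-1)
    return probs[0, 1].item()

print(hate_probability("I completely disagree with that argument."))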

class prompting.validators.reward.DirectPreferenceRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

reward_model_name: str = 'cerebras/btlm-3b-8k-base'#
reward_single(prompt, completion, name, with_penalty=True)#

Calculates a direct preference optimization (DPO) style reward for a completion, which is a reference model’s average log-probability for completion tokens given a prompt. Uses guidance from eric-mitchell/direct-preference-optimization.

Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

prompting.validators.reward.reward.BaseRewardEvent
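The following is a minimal sketch of the DPO-style score described above (the reference model's average log-probability over the completion tokens, given the prompt), using the documented reward_model_name. It is an illustration rather than the class's implementation: tokenizer boundary handling and the with_penalty behavior are simplified, and trust_remote_code is an assumption for this checkpoint.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "cerebras/btlm-3b-8k-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True)

def avg_completion_logprob(prompt: str, completion: str) -> float:
    """Average log-probability the reference model assigns to the completion
    tokens, conditioned on the prompt (a DPO-style reward)."""
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = ref_model(full_ids).logits              # (1, seq_len, vocab)
    # The token at position i is predicted by the logits at position i - 1.
    logprobs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    token_logprobs = logprobs.gather(-1, full_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    # Assumes the prompt tokenization is a prefix of the joint tokenization.
    return token_logprobs[0, prompt_len - 1:].mean().item()

print(avg_completion_logprob("Q: What is the capital of France?\nA:", " Paris."))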

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]

class prompting.validators.reward.OpenAssistantRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

reward_model_name: str = 'OpenAssistant/reward-model-deberta-v3-large-v2'#
reward_single(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

prompting.validators.reward.reward.BaseRewardEvent

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]
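The documented reward_model_name is the public OpenAssistant DeBERTa reward model. The sketch below shows how such a sequence-classification reward model is typically queried with a (prompt, completion) pair; the class's actual preprocessing and scaling may differ:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL = "OpenAssistant/reward-model-deberta-v3-large-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
reward_model = AutoModelForSequenceClassification.from_pretrained(MODEL)

def rlhf_score(prompt: str, completion: str) -> float:
    """Scalar preference score for a completion given its prompt."""
    inputs = tokenizer(prompt, completion, return_tensors="pt", truncation=True)
    with torch.no_grad():
        return reward_model(**inputs).logits[0].item()

print(rlhf_score("Explain photosynthesis simply.",
                 "Plants turn sunlight, water and CO2 into sugar and oxygen."))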

class prompting.validators.reward.ReciprocateRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

reward_model_path: str = 'reciprocate/gpt-j_rm_format-oa'#
revision: str = '501f895'#
reward(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

prompting.validators.reward.reward.BaseRewardEvent

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]

class prompting.validators.reward.RelevanceRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[RelevanceRewardEvent]

normalize_rewards(rewards)#

This method normalizes the given rewards by updating the moving mean and variance statistics. The rewards are first standardized, and then scaled to the 0-1 range using a cumulative distribution function (CDF) to ensure they are in a comparable range across different environments.

Note:

  • This function uses Welford's online algorithm to update the mean and variance.

  • It standardizes the reward values using the updated mean and variance.

  • It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF.

Parameters:

rewards (torch.FloatTensor) – The reward values to be normalized.

Returns:

The normalized reward values.

Return type:

torch.FloatTensor

reward(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

RelevanceRewardEvent

class prompting.validators.reward.BaseRewardModel#
abstract property name: str#
Return type:

str

__str__()#

Return str(self).

Return type:

str

__repr__()#

Return repr(self).

Return type:

str

abstract get_rewards(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (List[str]) –

  • name (str) –

Return type:

Union[torch.FloatTensor, dict]

normalize_rewards(rewards)#

This method normalizes the given rewards by updating the moving mean and variance statistics. The rewards are first standardized, and then scaled to the 0-1 range using a cumulative distribution function (CDF) to ensure they are in a comparable range across different environments.

Note:

  • This function uses Welford's online algorithm to update the mean and variance.

  • It standardizes the reward values using the updated mean and variance.

  • It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF.

Parameters:

rewards (torch.FloatTensor) – The reward values to be normalized.

Returns:

The normalized reward values.

Return type:

torch.FloatTensor
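The normalization described above can be made concrete with a small sketch, assuming batch-wise updates of the running statistics (the class's actual bookkeeping and edge-case handling may differ):

import math
import torch

class RunningNormalizer:
    """Running standardization followed by an erf-based CDF squash to [0, 1]."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.var = 0.0

    def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor:
        n = rewards.numel()
        batch_mean = rewards.mean().item()
        batch_var = rewards.var(unbiased=False).item()
        # Welford/Chan-style online update of the running mean and variance.
        total = self.count + n
        delta = batch_mean - self.mean
        self.mean += delta * n / total
        self.var = (self.count * self.var + n * batch_var
                    + delta ** 2 * self.count * n / total) / total
        self.count = total
        # Standardize, then map through the normal CDF:
        # Phi(z) = 0.5 * (1 + erf(z / sqrt(2))) lands in the 0-1 range.
        z = (rewards - self.mean) / (math.sqrt(self.var) + 1e-8)
        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

norm = RunningNormalizer()
print(norm.normalize_rewards(torch.tensor([0.1, 0.5, 0.9])))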

apply(prompt, responses, name)#

Applies the reward model across each call. Unsuccessful responses are zeroed.

Return type:

Union[torch.FloatTensor, dict]
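As a usage illustration, a minimal custom reward model only needs to provide the name property and get_rewards. The sketch below is hypothetical: it assumes the base class can be instantiated directly and that normalize_rewards can be reused as documented.

import torch
from typing import List
from prompting.validators.reward import BaseRewardModel

class LengthRewardModel(BaseRewardModel):
    """Toy reward model: longer completions get higher raw scores (illustration only)."""

    @property
    def name(self) -> str:
        return "length_reward_model"

    def get_rewards(self, prompt: str, completions: List[str], name: str) -> torch.FloatTensor:
        raw = torch.tensor([float(len(c)) for c in completions])
        # Reuse the base class's running standardization and CDF squash.
        return self.normalize_rewards(raw)

model = LengthRewardModel()
print(model.name)
print(model.get_rewards("prompt", ["short", "a much longer completion"], name="augment"))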

class prompting.validators.reward.DahoasRewardModel(path, device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:
  • path (str) –

  • device (str) –

property name: str#
Return type:

str

model_name = 'EleutherAI/gpt-j-6b'#
static load_weights(path)#
Parameters:

path (str) –

reward(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

prompting.validators.reward.reward.BaseRewardEvent

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]

forward(input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, labels=None, return_dict=False, output_attentions=False, output_hidden_states=False)#
class prompting.validators.reward.DiversityRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

diversity_model_path = 'sentence-transformers/all-mpnet-base-v2'#
get_embeddings(sentences)#

Runs a forward pass through the model to encode the given sentences.

Returns:

Embedding for the message.

Return type:

embedding (torch.FloatTensor)

Parameters:

sentences (List[str]) – text messages to be encoded
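The documented diversity_model_path is a standard sentence-transformers checkpoint. One way to obtain comparable embeddings (an assumption; the class may run the forward pass through transformers directly):

from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embeddings = encoder.encode(
    ["first completion", "second completion"],
    convert_to_tensor=True,
    normalize_embeddings=True,  # unit-norm vectors make cosine similarity a dot product
)
print(embeddings.shape)  # torch.Size([2, 768])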

update_historic_embeddings(embeddings)#
Parameters:

embeddings (torch.FloatTensor) –

get_historic_rewards(embeddings)#
Parameters:

embeddings (torch.FloatTensor) –

Return type:

torch.FloatTensor

get_batch_rewards(embeddings)#
Parameters:

embeddings (torch.FloatTensor) –

Return type:

torch.FloatTensor
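The batch-level diversity computation is not spelled out above. A common formulation, shown purely as an assumption, rewards completions whose embeddings are dissimilar (in cosine terms) to every other completion in the batch:

import torch

def batch_diversity(embeddings: torch.FloatTensor) -> torch.FloatTensor:
    """Per-completion diversity: 1 minus the highest cosine similarity to any
    other embedding in the batch (higher means more distinct). Illustration only."""
    normed = torch.nn.functional.normalize(embeddings, dim=-1)
    sim = normed @ normed.T        # pairwise cosine similarities
    sim.fill_diagonal_(-1.0)       # ignore self-similarity
    return 1.0 - sim.max(dim=-1).values

print(batch_diversity(torch.randn(4, 768)))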

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[DiversityRewardEvent]

normalize_rewards(raw_rewards)#

This method normalizes the given rewards by updating the moving mean and variance statistics. The rewards are first standardized, and then scaled to the 0-1 range using a cumulative distribution function (CDF) to ensure they’re in a comparable range across different environments.

Args: rewards (torch.FloatTensor): The reward values to be normalized.

Returns: torch.FloatTensor: The normalized reward values.

Note: - This function uses Welford’s online algorithm to update the mean and variance. - It standardizes the reward values using the updated mean and variance. - It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF.

Parameters:

raw_rewards (torch.FloatTensor) –

Return type:

torch.FloatTensor

class prompting.validators.reward.PromptRewardModel(device)#

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) –

property name: str#
Return type:

str

reward_model_name: str = 'VMware/open-llama-7b-open-instruct'#
reward(prompt, completion, name)#
Parameters:
  • prompt (str) –

  • completion (str) –

  • name (str) –

Return type:

prompting.validators.reward.reward.BaseRewardEvent

get_rewards(prompt, completions, name)#
Parameters:
  • prompt (str) –

  • completions (List[str]) –

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]

class prompting.validators.reward.RewardModelType(*args, **kwds)#

Bases: enum.Enum

Create a collection of name/value pairs.

Example enumeration:

>>> class Color(Enum):
...     RED = 1
...     BLUE = 2
...     GREEN = 3

Access them by:

  • attribute access:

    >>> Color.RED
    <Color.RED: 1>
    
  • value lookup:

    >>> Color(1)
    <Color.RED: 1>
    
  • name lookup:

    >>> Color['RED']
    <Color.RED: 1>
    

Enumerations can be iterated over, and know how many members they have:

>>> len(Color)
3
>>> list(Color)
[<Color.RED: 1>, <Color.BLUE: 2>, <Color.GREEN: 3>]

Methods can be added to enumerations, and members can have their own attributes – see the documentation for details.

dpo = 'dpo_reward_model'#
rlhf = 'rlhf_reward_model'#
reciprocate = 'reciprocate_reward_model'#
dahoas = 'dahoas_reward_model'#
diversity = 'diversity_reward_model'#
prompt = 'prompt_reward_model'#
blacklist = 'blacklist_filter'#
nsfw = 'nsfw_filter'#
relevance = 'relevance_filter'#
relevance_bert = 'relevance_bert'#
relevance_mpnet = 'relevance_mpnet'#
task_validator = 'task_validator_filter'#
keyword_match = 'keyword_match_penalty'#
class prompting.validators.reward.DefaultRewardFrameworkConfig#

Reward framework default configuration. Note: All the weights should add up to 1.0.

dpo_model_weight: float = 0.6#
rlhf_model_weight: float = 0#
reciprocate_model_weight: float = 0.4#
dahoas_model_weight: float = 0#
prompt_model_weight: float = 0#
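As a sanity check on the constraint that the weights add up to 1.0, and to illustrate how per-model rewards could be combined, consider the sketch below. The weighted sum is an assumption for illustration, not the package's combination logic, and it assumes the config dataclass instantiates with its documented defaults.

import torch
from prompting.validators.reward import DefaultRewardFrameworkConfig

cfg = DefaultRewardFrameworkConfig()
weights = {
    "dpo": cfg.dpo_model_weight,
    "rlhf": cfg.rlhf_model_weight,
    "reciprocate": cfg.reciprocate_model_weight,
    "dahoas": cfg.dahoas_model_weight,
    "prompt": cfg.prompt_model_weight,
}
assert abs(sum(weights.values()) - 1.0) < 1e-6  # weights must add up to 1.0

# Hypothetical normalized rewards from each model for a batch of 3 completions.
per_model_rewards = {name: torch.rand(3) for name in weights}
combined = sum(w * per_model_rewards[name] for name, w in weights.items())
print(combined)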