prompting.validators.reward.nsfw
Module Contents

Classes
- class prompting.validators.reward.nsfw.NSFWRewardEvent
- class prompting.validators.reward.nsfw.NSFWRewardModel(device)

  Bases: prompting.validators.reward.reward.BaseRewardModel

  Parameters:
    device (str) – device on which the NSFW filter model is run (e.g. "cuda" or "cpu")
  - nsfw_filter_model_path = 'facebook/roberta-hate-speech-dynabench-r4-target'

    Hugging Face identifier of the hate-speech classifier used as the NSFW filter.
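The snippet below is a minimal sketch of loading this checkpoint with Hugging Face transformers; the exact loading code inside NSFWRewardModel may differ.

```python
# Sketch: loading the classifier referenced by nsfw_filter_model_path.
# The Auto* loading shown here is an assumption about how the checkpoint is used.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_path = "facebook/roberta-hate-speech-dynabench-r4-target"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)
classifier = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
classifier.eval()
```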
  - reward(prompt, completion, name)

    Parameters:
      - prompt (str) –
      - completion (str) –
      - name (str) –
    Return type:
      NSFWRewardEvent
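As an illustration of the kind of check reward can perform, the sketch below scores a single completion with the hate-speech classifier; the threshold, label index, and 0/1 reward mapping are assumptions, not the model's exact logic.

```python
# Sketch: scoring one completion with a hate-speech classifier.
# Threshold and 0/1 reward mapping are illustrative assumptions.
import torch

def score_completion(completion, tokenizer, classifier, device, boundary=0.5):
    inputs = tokenizer(completion, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = classifier(**inputs).logits
    # For this checkpoint, label index 1 corresponds to "hate".
    hate_prob = torch.softmax(logits, dim=-1)[0, 1].item()
    # Zero reward for flagged text, full reward otherwise (assumed convention).
    return 0.0 if hate_prob > boundary else 1.0
```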
  - get_rewards(prompt, completions, name)

    Parameters:
      - prompt (str) –
      - completions (List[str]) –
      - name (str) –
    Return type:
      List[NSFWRewardEvent]
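A hedged usage sketch for batch scoring: the prompt, completions, name, and device values below are placeholders, not values taken from the library.

```python
# Sketch: scoring several completions at once; argument values are placeholders.
from prompting.validators.reward.nsfw import NSFWRewardModel

nsfw_model = NSFWRewardModel(device="cuda")  # or "cpu"
events = nsfw_model.get_rewards(
    prompt="Summarize the article below.",
    completions=["A neutral summary.", "An abusive rant."],
    name="followup",  # placeholder task name
)
for event in events:
    print(event)
```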
  - normalize_rewards(rewards)

    Parameters:
      rewards (torch.FloatTensor) –
    Return type:
      torch.FloatTensor
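The sketch below shows one plausible normalization over a torch.FloatTensor of raw scores; the actual behaviour of normalize_rewards may differ (already-bounded NSFW scores could simply be passed through).

```python
# Sketch: a simple min-max rescale into [0, 1]; an assumption, not the
# model's documented normalization.
import torch

def normalize_rewards(rewards: torch.FloatTensor) -> torch.FloatTensor:
    span = rewards.max() - rewards.min()
    if span == 0:
        return rewards
    return (rewards - rewards.min()) / span

print(normalize_rewards(torch.FloatTensor([0.0, 0.5, 1.0])))
```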