prompting.validators.gating
Module Contents
Classes
- BaseGatingModel: abstract base class for the gating model; it defines the interface for the gating model.
- GatingModel: PyTorch module that encapsulates the gating model functionality.
- SentenceEmbedGatingModel: PyTorch module that encapsulates a custom version of a gating model based on sentence transformers.
- class prompting.validators.gating.BaseGatingModel
Bases: torch.nn.Module, abc.ABC
Abstract base class for gating models. It defines the interface (the abstract forward, backward, and resync methods, plus the configuration class methods) that concrete gating models implement.
- classmethod add_args(parser)
Adds the command-line arguments that configure the gating model to the parser. The arguments added are:
  - --gating.model_name: name of the pre-trained transformer-based language model used as the encoding layer of the gating model (default: 'EleutherAI/gpt-neo-125m').
  - --gating.num_uids: number of uids to gate on (default: 4096).
  - --gating.learning_rate: learning rate for the gating model optimizer (default: 0.01).
  - --gating.momentum: momentum for the gating model optimizer (default: 0.9).
- Parameters:
parser (argparse.ArgumentParser) – parser to which the gating arguments are added.
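A minimal sketch of how these four options could be registered with plain argparse. The function name add_gating_args is hypothetical and the help strings paraphrase the list above; the real class method may differ in detail. Note that argparse keeps the dot in the destination name, so the parsed values are read back with getattr.

```python
import argparse


def add_gating_args(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in for BaseGatingModel.add_args: registers the
    four documented gating options with their documented defaults."""
    parser.add_argument(
        "--gating.model_name",
        type=str,
        default="EleutherAI/gpt-neo-125m",
        help="Pre-trained language model used as the gating encoder.",
    )
    parser.add_argument(
        "--gating.num_uids", type=int, default=4096,
        help="Number of uids to gate on.",
    )
    parser.add_argument(
        "--gating.learning_rate", type=float, default=0.01,
        help="Learning rate for the gating model optimizer.",
    )
    parser.add_argument(
        "--gating.momentum", type=float, default=0.9,
        help="Momentum for the gating model optimizer.",
    )


parser = argparse.ArgumentParser()
add_gating_args(parser)
# Override one option; the others keep their defaults.
args = parser.parse_args(["--gating.num_uids", "256"])
print(getattr(args, "gating.num_uids"))       # 256
print(getattr(args, "gating.learning_rate"))  # 0.01
```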
- abstract forward(message)
Runs a forward pass through the gating model.
- Parameters:
message (str) – text message to be encoded.
- Return type:
torch.FloatTensor
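- abstract backward(scores, rewards)
Runs a backward pass through the gating model.
- Parameters:
scores (torch.FloatTensor) – scores for each uid, as output by the gating model.
rewards (torch.FloatTensor) – rewards for each uid, as output by the reward model.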
- abstract resync(previous_metagraph, metagraph)
Resyncs the gating model with the latest state of the network.
- Parameters:
previous_metagraph (bittensor.metagraph.Metagraph) – previous state of the metagraph, before the resync.
metagraph (bittensor.metagraph.Metagraph) – latest state of the metagraph, with updated uids and hotkeys.
- classmethod config()
Returns a configuration object that contains the command-line arguments for the gating model.
- classmethod check_config(config)
Validates the configuration object for the gating model.
- Parameters:
config (bittensor.Config) – configuration object to validate.
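To illustrate the contract BaseGatingModel imposes on subclasses, here is a dependency-free sketch. GatingInterface and UniformGate are hypothetical names; the sketch drops the torch.nn.Module base and uses Python lists in place of torch.FloatTensor so that it runs without PyTorch. Because the abstract methods are marked with abc.abstractmethod, the base class itself cannot be instantiated.

```python
import abc
from typing import List, Sequence


class GatingInterface(abc.ABC):
    """Hypothetical, dependency-free mirror of BaseGatingModel's abstract
    methods (the real class also inherits from torch.nn.Module)."""

    @abc.abstractmethod
    def forward(self, message: str) -> List[float]:
        """Return one score per uid for the given message."""

    @abc.abstractmethod
    def backward(self, scores: Sequence[float], rewards: Sequence[float]) -> None:
        """Update the model from its scores and the observed rewards."""

    @abc.abstractmethod
    def resync(self, previous_metagraph, metagraph) -> None:
        """Adapt to changes in the network's uids and hotkeys."""


class UniformGate(GatingInterface):
    """Toy subclass: gives every uid the same score and never learns."""

    def __init__(self, num_uids: int):
        self.num_uids = num_uids

    def forward(self, message: str) -> List[float]:
        return [1.0 / self.num_uids] * self.num_uids

    def backward(self, scores, rewards) -> None:
        pass  # nothing to learn in this toy model

    def resync(self, previous_metagraph, metagraph) -> None:
        pass  # uniform scores do not depend on hotkeys


gate = UniformGate(num_uids=4)
print(gate.forward("hello"))  # [0.25, 0.25, 0.25, 0.25]

# The abstract base class itself cannot be instantiated.
try:
    GatingInterface()
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)
```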
- class prompting.validators.gating.GatingModel(metagraph, config=None, model_name=None, num_uids=None)
Bases: BaseGatingModel
This class is a PyTorch module that encapsulates the gating model functionality.
The backward method runs a backward pass through the model using the mean squared error between
the normalized scores and the normalized rewards as the loss function.
The forward method runs a forward pass through the model, encoding the input message and generating scores
for each uid in the network. The scores are returned as a tensor.
- Parameters:
metagraph (bittensor.metagraph.Metagraph)
config (bittensor.config)
model_name (str)
num_uids (int)
- backward(scores, rewards)
Runs a backward pass through the model.
- Parameters:
scores (torch.FloatTensor of shape (metagraph.n)) – scores for each uid, as output by the gating model.
rewards (torch.FloatTensor of shape (metagraph.n)) – rewards for each uid, as output by the reward model.
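The loss described for GatingModel (mean squared error between the normalized scores and the normalized rewards) can be sketched in plain Python as follows. This reference does not say which normalization is applied; the sketch assumes a softmax over uids, and the helper names softmax and gating_loss are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gating_loss(scores: Sequence[float], rewards: Sequence[float]) -> float:
    """Mean squared error between softmax-normalized scores and rewards.

    The softmax normalization is an assumption of this sketch.
    """
    if len(scores) != len(rewards):
        raise ValueError("scores and rewards must both have one entry per uid")
    p = softmax(scores)
    q = softmax(rewards)
    return sum((a - b) ** 2 for a, b in zip(p, q)) / len(p)


print(gating_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
print(gating_loss([0.0, 0.0], [0.0, 10.0]) > 0)        # True
```

Softmax is shift-invariant, so under this assumption adding the same constant to every score leaves the loss unchanged; only the relative scores matter. In the PyTorch module the loss would then be back-propagated and an optimizer step taken, configured by --gating.learning_rate and --gating.momentum.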
- forward(message)
Runs a forward pass through the model.
- Parameters:
message (str) – text message to be encoded.
- Returns:
Scores for each uid, as output by the gating model.
- Return type:
torch.FloatTensor of shape (network_size)
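A toy sketch of the shape of the forward pass: encode the message, then produce one score per uid. GatingModel uses a pre-trained transformer (by default EleutherAI/gpt-neo-125m) as its encoder; here a hashed bag-of-words encoder stands in for it, and the single linear layer mapping the embedding to num_uids scores is an assumption of this sketch. ToyGate, encode, and EMBED_DIM are hypothetical names.

```python
import random
import zlib
from typing import List

EMBED_DIM = 16  # hypothetical embedding size for this toy example


def encode(message: str, dim: int = EMBED_DIM) -> List[float]:
    """Hypothetical stand-in for the transformer encoder: a hashed
    bag-of-words vector (crc32 keeps hashing deterministic across runs)."""
    vec = [0.0] * dim
    for token in message.lower().split():
        vec[zlib.crc32(token.encode("utf-8")) % dim] += 1.0
    return vec


class ToyGate:
    """Encoder followed by a linear layer with one output per uid."""

    def __init__(self, num_uids: int, dim: int = EMBED_DIM, seed: int = 0):
        rng = random.Random(seed)
        self.weights = [[rng.gauss(0.0, 0.1) for _ in range(dim)]
                        for _ in range(num_uids)]
        self.bias = [0.0] * num_uids

    def forward(self, message: str) -> List[float]:
        x = encode(message, len(self.weights[0]))
        # One dot product per uid: scores[uid] = weights[uid] . x + bias[uid]
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weights, self.bias)]


gate = ToyGate(num_uids=8)
scores = gate.forward("What is the capital of France?")
print(len(scores))  # 8, one score per uid
```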
- resync(previous_metagraph, metagraph)
Resyncs the gating model with the latest state of the network.
- Parameters:
previous_metagraph (bittensor.metagraph.Metagraph) – previous state of the metagraph, before the resync.
metagraph (bittensor.metagraph.Metagraph) – latest state of the metagraph, with updated uids and hotkeys.
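resync receives the metagraph before and after an update. A plausible first step, assumed here rather than taken from the source, is to find the uids whose hotkey changed or which are new, since those slots now belong to a different participant. ToyMetagraph models only a hotkeys list and is a hypothetical stand-in for bittensor.metagraph.Metagraph.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyMetagraph:
    """Hypothetical stand-in: only the per-uid hotkey list is modeled."""
    hotkeys: List[str]


def changed_uids(previous: ToyMetagraph, current: ToyMetagraph) -> List[int]:
    """Return uids whose hotkey differs between the two metagraph states,
    plus any uids that exist only in the current state."""
    changed = [
        uid
        for uid, (old, new) in enumerate(zip(previous.hotkeys, current.hotkeys))
        if old != new
    ]
    # uids beyond the previous size are new (empty range if the network shrank)
    changed.extend(range(len(previous.hotkeys), len(current.hotkeys)))
    return changed


prev = ToyMetagraph(hotkeys=["hk-a", "hk-b", "hk-c"])
curr = ToyMetagraph(hotkeys=["hk-a", "hk-x", "hk-c", "hk-d"])
print(changed_uids(prev, curr))  # [1, 3]
```

What the gating model then does with those uids (for example, resetting their learned scores) is not specified in this reference.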
- class prompting.validators.gating.SentenceEmbedGatingModel(metagraph, config=None, model_name=None, num_uids=None)
Bases: BaseGatingModel
This class is a PyTorch module that encapsulates a custom version of a gating model based on sentence transformers.
- The backward method runs a backward pass through the model using the mean squared error between the normalized
scores and the normalized rewards as the loss function.
- The forward method runs a forward pass through the model, encoding the input message and generating scores
for each uid in the network. The scores are returned as a tensor.
- Parameters:
metagraph (bittensor.metagraph.Metagraph)
config (bittensor.config)
model_name (str)
num_uids (int)
- mean_pooling(model_output, attention_mask)
Applies mean pooling to the token embeddings generated by the model.
- Parameters:
model_output (torch.Tensor) – embedding model output, whose first element contains the token embeddings.
attention_mask (torch.Tensor) – attention mask indicating which tokens are valid.
- Returns:
Mean-pooled representation of the token embeddings.
Notes
The function computes the mean-pooled representation over the valid tokens only, as indicated by the attention mask:
  - input_mask_expanded is created by expanding the attention mask to match the size of the token embeddings.
  - The token embeddings are multiplied element-wise by input_mask_expanded and summed over the token axis.
  - The sum is divided by the sum of input_mask_expanded, clamped to a minimum of 1e-9 to avoid division by zero.
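The Notes above amount to a masked mean over the token axis. This plain-Python sketch takes the token embeddings directly (a caller would pass model_output[0]) together with a 0/1 attention mask; in PyTorch the same steps are usually written with unsqueeze, expand, sum, and torch.clamp.

```python
from typing import List


def mean_pooling(token_embeddings: List[List[List[float]]],
                 attention_mask: List[List[int]]) -> List[List[float]]:
    """Masked mean over the token axis.

    token_embeddings: [batch][seq_len][hidden]
    attention_mask:   [batch][seq_len], 1 = real token, 0 = padding
    Returns:          [batch][hidden]
    """
    pooled = []
    for tokens, mask in zip(token_embeddings, attention_mask):
        hidden = len(tokens[0])
        summed = [0.0] * hidden
        for vec, m in zip(tokens, mask):
            for i in range(hidden):
                summed[i] += vec[i] * m  # padded tokens contribute 0
        # Clamp the token count so an all-padding row does not divide by zero.
        count = max(float(sum(mask)), 1e-9)
        pooled.append([s / count for s in summed])
    return pooled


emb = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]  # last token is padding
mask = [[1, 1, 0]]
print(mean_pooling(emb, mask))  # [[2.0, 3.0]]
```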
- forward(message)
Runs a forward pass through the model.
- Parameters:
message (str) – text message to be encoded.
- Returns:
Scores for each uid, as output by the gating model.
- Return type:
torch.FloatTensor of shape (network_size)
- backward(scores, rewards)
Runs a backward pass through the model.
- Parameters:
scores (torch.FloatTensor of shape (metagraph.n)) – scores for each uid, as output by the gating model.
rewards (torch.FloatTensor of shape (metagraph.n)) – rewards for each uid, as output by the reward model.
- resync(previous_metagraph, metagraph)
Resyncs the gating model with the latest state of the network.
- Parameters:
previous_metagraph (bittensor.metagraph.Metagraph) – previous state of the metagraph, before the resync.
metagraph (bittensor.metagraph.Metagraph) – latest state of the metagraph, with updated uids and hotkeys.