TorchFort C API¶
These are all the types and functions available in the TorchFort C API.
General¶
Types¶
torchfort_datatype_t¶
-
enum torchfort_datatype_t¶
This enum defines the data types supported.
Values:
-
enumerator TORCHFORT_FLOAT¶
-
enumerator TORCHFORT_DOUBLE¶
torchfort_result_t¶
-
enum torchfort_result_t¶
This enum defines the possible return values from TorchFort. Most functions in the TorchFort library return one of these values to indicate whether an operation completed successfully or an error occurred.
Values:
-
enumerator TORCHFORT_RESULT_SUCCESS¶
The operation completed successfully.
-
enumerator TORCHFORT_RESULT_INVALID_USAGE¶
A user error, typically an invalid argument.
-
enumerator TORCHFORT_RESULT_NOT_SUPPORTED¶
A user error, requesting an invalid or unsupported operation configuration.
-
enumerator TORCHFORT_RESULT_INTERNAL_ERROR¶
An internal library error, should be reported.
-
enumerator TORCHFORT_RESULT_CUDA_ERROR¶
An error occurred in the CUDA Runtime.
-
enumerator TORCHFORT_RESULT_MPI_ERROR¶
An error occurred in the MPI library.
-
enumerator TORCHFORT_RESULT_NCCL_ERROR¶
An error occurred in the NCCL library.
Global Context Settings¶
These are global routines which affect the behavior of the libtorch backend. It is therefore recommended to call these functions before any other TorchFort calls are made.
torchfort_set_cudnn_benchmark¶
-
torchfort_result_t torchfort_set_cudnn_benchmark(const bool flag)¶
Utility function to enable/disable cuDNN runtime autotuning in PyTorch.
- Parameters
flag – [in] Boolean value to set the cuDNN benchmarking flag to.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
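A minimal usage sketch, calling this routine at program start and checking the returned torchfort_result_t (the header name torchfort.h and the error-checking macro are illustrative assumptions):

    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <torchfort.h>  /* assumed public header name */

    /* Illustrative helper: abort on any TorchFort error code. */
    #define CHECK_TORCHFORT(call)                                            \
      do {                                                                   \
        torchfort_result_t res_ = (call);                                    \
        if (res_ != TORCHFORT_RESULT_SUCCESS) {                              \
          fprintf(stderr, "TorchFort error %d at %s:%d\n", (int)res_,        \
                  __FILE__, __LINE__);                                       \
          exit(EXIT_FAILURE);                                                 \
        }                                                                    \
      } while (0)

    int main(void) {
      /* Enable cuDNN autotuning before any other TorchFort calls. */
      CHECK_TORCHFORT(torchfort_set_cudnn_benchmark(true));
      return 0;
    }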
Supervised Learning¶
Model Creation¶
torchfort_create_model¶
-
torchfort_result_t torchfort_create_model(const char *name, const char *config_fname, int device)¶
Creates a model instance from a provided configuration file.
- Parameters
name – [in] A name to assign to the created model instance to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined model configuration file to use.
device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), model will be placed on CPU. For values >= 0, model will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
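A minimal sketch that registers a model under the key "mymodel" and places it on GPU 0 (the model name, configuration file name, and header name are illustrative assumptions):

    #include <torchfort.h>  /* assumed public header name */

    void setup_model(void) {
      /* Register a model from a user-defined configuration file on GPU 0. */
      torchfort_result_t res =
          torchfort_create_model("mymodel", "config_mlp.yaml", 0);
      if (res != TORCHFORT_RESULT_SUCCESS) {
        /* handle error */
      }
    }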
torchfort_create_distributed_model¶
-
torchfort_result_t torchfort_create_distributed_model(const char *name, const char *config_fname, MPI_Comm mpi_comm, int device)¶
Creates a distributed data-parallel model from a provided configuration file.
- Parameters
name – [in] A name to assign to created model to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined model configuration file to use.
mpi_comm – [in] MPI communicator to use to initialize NCCL communication library for data-parallel communication.
device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), model will be placed on CPU. For values >= 0, model will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
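A minimal sketch that creates one data-parallel replica per MPI rank (the rank-to-GPU mapping, model name, and file names are illustrative assumptions):

    #include <mpi.h>
    #include <torchfort.h>  /* assumed public header name */

    int main(int argc, char *argv[]) {
      MPI_Init(&argc, &argv);

      int rank;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);

      /* Illustrative mapping: one GPU per rank, assuming 4 GPUs per node. */
      int device = rank % 4;
      torchfort_result_t res = torchfort_create_distributed_model(
          "mymodel", "config_mlp.yaml", MPI_COMM_WORLD, device);
      /* check res ... */

      MPI_Finalize();
      return 0;
    }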
Model Training/Inference¶
torchfort_train¶
-
torchfort_result_t torchfort_train(const char *name, void *input, size_t input_dim, int64_t *input_shape, void *label, size_t label_dim, int64_t *label_shape, void *loss_val, torchfort_datatype_t dtype, cudaStream_t stream)¶
Runs a training iteration of a model instance using provided input and label data.
- Parameters
name – [in] The name of model instance to use, as defined during model creation.
input – [in] A pointer to a memory buffer containing input data.
input_dim – [in] Rank of the input data.
input_shape – [in] A pointer to an array specifying the shape of the input data. Length should be equal to the rank of the input data.
label – [in] A pointer to a memory buffer containing label data.
label_dim – [in] Rank of the label data.
label_shape – [in] A pointer to an array specifying the shape of the label data. Length should be equal to the rank of the label data.
loss_val – [out] A pointer to a memory location to write the loss value computed during the training iteration.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
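A minimal sketch of a single training step on host-resident float data (the model name, batch size, and feature sizes are illustrative assumptions):

    #include <stdint.h>
    #include <torchfort.h>  /* assumed public header name */

    void run_training_step(float *input, float *label) {
      int64_t input_shape[2] = {32, 128};  /* batch x input features */
      int64_t label_shape[2] = {32, 1};    /* batch x label size */
      float loss;

      /* Run one training iteration; the null stream is ignored for CPU models. */
      torchfort_result_t res = torchfort_train("mymodel", input, 2, input_shape,
                                               label, 2, label_shape, &loss,
                                               TORCHFORT_FLOAT, 0);
      /* check res, then log or monitor loss ... */
    }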
torchfort_inference¶
-
torchfort_result_t torchfort_inference(const char *name, void *input, size_t input_dim, int64_t *input_shape, void *output, size_t output_dim, int64_t *output_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Runs inference on a model using provided input data.
- Parameters
name – [in] The name of model instance to use, as defined during model creation.
input – [in] A pointer to a memory buffer containing input data.
input_dim – [in] Rank of the input data.
input_shape – [in] A pointer to an array specifying the shape of the input data. Length should be equal to the rank of the input data.
output – [inout] A pointer to a memory buffer to write output data.
output_dim – [in] Rank of the output data.
output_shape – [in] A pointer to an array specifying the shape of the output data. Length should be equal to the rank of the output data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
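A minimal sketch of running inference on device-resident data with a user-managed CUDA stream (the model name and shapes are illustrative assumptions):

    #include <stdint.h>
    #include <cuda_runtime.h>
    #include <torchfort.h>  /* assumed public header name */

    void run_inference(float *d_input, float *d_output, cudaStream_t stream) {
      int64_t input_shape[2]  = {32, 128};
      int64_t output_shape[2] = {32, 1};

      /* The operation is enqueued on `stream`; synchronize the stream before
       * reading d_output on the host. */
      torchfort_result_t res = torchfort_inference(
          "mymodel", d_input, 2, input_shape, d_output, 2, output_shape,
          TORCHFORT_FLOAT, stream);
      /* check res ... */
    }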
Model Management¶
torchfort_save_model¶
-
torchfort_result_t torchfort_save_model(const char *name, const char *fname)¶
Saves a model to file.
- Parameters
name – [in] The name of model instance to save, as defined during model creation.
fname – [in] The filename to save the model weights to.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_load_model¶
-
torchfort_result_t torchfort_load_model(const char *name, const char *fname)¶
Loads a model from a file.
- Parameters
name – [in] The name of model instance to load the model weights to, as defined during model creation.
fname – [in] The filename to load the model from.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
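A minimal sketch of saving the weights of a registered model and loading them back later (the model name and file name are illustrative assumptions):

    #include <torchfort.h>  /* assumed public header name */

    void save_and_restore(void) {
      /* Persist the current weights of "mymodel". */
      torchfort_save_model("mymodel", "mymodel_weights.pt");

      /* ... later, after re-creating "mymodel" from its configuration file,
       * restore the saved weights. */
      torchfort_load_model("mymodel", "mymodel_weights.pt");
    }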
torchfort_save_checkpoint¶
-
torchfort_result_t torchfort_save_checkpoint(const char *name, const char *checkpoint_dir)¶
Saves a training checkpoint to a directory. In contrast to torchfort_save_model, this function saves additional state to restart training, like the optimizer states and learning rate schedule progress.
- Parameters
name – [in] The name of model instance to save, as defined during model creation.
checkpoint_dir – [in] A writeable filesystem path to a directory to save the checkpoint data to.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_load_checkpoint¶
-
torchfort_result_t torchfort_load_checkpoint(const char *name, const char *checkpoint_dir, int64_t *step_train, int64_t *step_inference)¶
Loads a training checkpoint from a directory. In contrast to torchfort_load_model, this function loads additional state to restart training, like the optimizer states and learning rate schedule progress.
- Parameters
name – [in] The name of model instance to load checkpoint data into, as defined during model creation.
checkpoint_dir – [in] A readable filesystem path to a directory to load the checkpoint data from.
step_train – [out] A pointer to an integer to write current training step for loaded checkpoint.
step_inference – [out] A pointer to an integer to write current inference step for loaded checkpoint.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
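A minimal sketch of restarting training from a checkpoint directory (the model name and directory path are illustrative assumptions):

    #include <stdint.h>
    #include <torchfort.h>  /* assumed public header name */

    void resume_training(void) {
      int64_t step_train = 0, step_inference = 0;

      /* Restore model, optimizer, and LR scheduler state and recover the
       * step counters for the loaded checkpoint. */
      torchfort_result_t res = torchfort_load_checkpoint(
          "mymodel", "./checkpoints/latest", &step_train, &step_inference);
      /* check res; continue the training loop starting at step_train ... */
    }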
Weights and Biases Logging¶
torchfort_wandb_log_int¶
-
torchfort_result_t torchfort_wandb_log_int(const char *name, const char *metric_name, int64_t step, int value)¶
Writes an integer value to a Weights and Biases log. Use the _float and _double variants to write float and double values, respectively.
- Parameters
name – [in] The name of model instance to associate this metric value with, as defined during model creation.
metric_name – [in] Metric label.
step – [in] Training/inference step to associate with metric value.
value – [in] Metric value to log.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_wandb_log_float¶
-
torchfort_result_t torchfort_wandb_log_float(const char *name, const char *metric_name, int64_t step, float value)¶
torchfort_wandb_log_double¶
-
torchfort_result_t torchfort_wandb_log_double(const char *name, const char *metric_name, int64_t step, double value)¶
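A minimal sketch of logging the loss from a training step (the model name, metric label, and step counter are illustrative assumptions):

    #include <stdint.h>
    #include <torchfort.h>  /* assumed public header name */

    void log_loss(float loss, int64_t step) {
      /* Associate the metric with the "mymodel" instance under the
       * illustrative label "train_loss". */
      torchfort_wandb_log_float("mymodel", "train_loss", step, loss);
    }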
Reinforcement Learning¶
Similar to other reinforcement learning frameworks such as Spinning Up from OpenAI or Stable Baselines, we distinguish between on-policy and off-policy algorithms since those two types require different APIs.
Off-Policy Algorithms¶
System Creation¶
Basic routines to create and register a reinforcement learning system in the internal registry. A (synchronous) data parallel distributed option is available.
torchfort_rl_off_policy_create_system¶
-
torchfort_result_t torchfort_rl_off_policy_create_system(const char *name, const char *config_fname, int model_device, int rb_device)¶
Creates an off-policy reinforcement learning training system instance from a provided configuration file.
- Parameters
name – [in] A name to assign to the created training system instance to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined configuration file to use.
model_device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), model will be placed on CPU. For values >= 0, model will be placed on GPU with index corresponding to value.
rb_device – [in] Which device to place the replay buffer on. For TORCHFORT_DEVICE_CPU (-1), the replay buffer will be placed on CPU. For values >= 0, the replay buffer will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
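A minimal sketch of creating an off-policy system with the models on GPU 0 and the replay buffer kept in CPU memory (the system name and configuration file name are illustrative assumptions):

    #include <torchfort.h>  /* assumed public header name */

    void setup_rl_system(void) {
      /* Models on GPU 0, replay buffer on the CPU. */
      torchfort_result_t res = torchfort_rl_off_policy_create_system(
          "rl_system", "td3_config.yaml", 0, TORCHFORT_DEVICE_CPU);
      /* check res ... */
    }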
torchfort_rl_off_policy_create_distributed_system¶
-
torchfort_result_t torchfort_rl_off_policy_create_distributed_system(const char *name, const char *config_fname, MPI_Comm mpi_comm, int model_device, int rb_device)¶
Creates a (synchronous) data-parallel off-policy reinforcement learning system instance from a provided configuration file.
- Parameters
name – [in] A name to assign to the created training system instance to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined configuration file to use.
mpi_comm – [in] MPI communicator to use to initialize NCCL communication library for data-parallel communication.
model_device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), model will be placed on CPU. For values >= 0, model will be placed on GPU with index corresponding to value.
rb_device – [in] Which device to place the replay buffer on. For TORCHFORT_DEVICE_CPU (-1), the replay buffer will be placed on CPU. For values >= 0, the replay buffer will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
Training/Evaluation¶
These routines are used for training the reinforcement learning system or for steering the environment.
torchfort_rl_off_policy_train_step¶
-
torchfort_result_t torchfort_rl_off_policy_train_step(const char *name, float *p_loss_val, float *q_loss_val, cudaStream_t stream)¶
Runs a training iteration of an off-policy reinforcement learning instance and returns loss values for the policy and value functions.
This routine samples a batch of the specified size from the replay buffer according to the buffer's sampling procedure and performs a train step using this sample. The details of the training procedure are abstracted away from the user and depend on the chosen system algorithm.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
p_loss_val – [out] A pointer to a memory location to write the policy loss value computed during the training iteration.
q_loss_val – [out] A pointer to a memory location to write the critic loss value computed during the training iteration. If the system uses multiple critics, the average across all critics is returned.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
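A minimal sketch of one training update (the system name is an illustrative assumption); in practice this would be called only after torchfort_rl_off_policy_is_ready reports that the system is ready:

    #include <torchfort.h>  /* assumed public header name */

    void training_update(void) {
      float p_loss, q_loss;

      /* One off-policy training iteration on the null stream; p_loss receives
       * the policy loss and q_loss the (averaged) critic loss. */
      torchfort_rl_off_policy_train_step("rl_system", &p_loss, &q_loss, 0);
    }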
torchfort_rl_off_policy_predict_explore¶
-
torchfort_result_t torchfort_rl_off_policy_predict_explore(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Suggests an action based on the current state of the system and adds noise as specified by the corresponding reinforcement learning system.
Depending on the reinforcement learning algorithm used, the prediction is performed by the main network (not the target network). In contrast to torchfort_rl_off_policy_predict, this routine adds noise and is therefore explorative. The kind of noise is specified during system creation.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [inout] A pointer to a memory buffer to write action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_off_policy_predict¶
-
torchfort_result_t torchfort_rl_off_policy_predict(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Suggests an action based on the current state of the system.
Depending on the algorithm used, the prediction is performed by the target network. In contrast to torchfort_rl_off_policy_predict_explore, this routine does not add noise, which means it is exploitative.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [inout] A pointer to a memory buffer to write action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_off_policy_evaluate¶
-
torchfort_result_t torchfort_rl_off_policy_evaluate(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, void *reward, size_t reward_dim, int64_t *reward_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Predicts the future reward based on the current state and selected action.
Depending on the learning algorithm, the routine queries the target critic networks for this. The routine averages the predictions over all critics.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [in] A pointer to a memory buffer containing action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
reward – [inout] A pointer to a memory buffer to write reward data.
reward_dim – [in] Rank of the reward data.
reward_shape – [in] A pointer to an array specifying the shape of the reward data. Length should be equal to the rank of the reward data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
System Management¶
The purpose of these routines is to manage the reinforcement learning system's internal data. They allow the user to add tuples to the replay buffer and query the system for readiness. Additionally, save and restore functionality is provided.
torchfort_rl_off_policy_update_replay_buffer¶
-
torchfort_result_t torchfort_rl_off_policy_update_replay_buffer(const char *name, void *state_old, void *state_new, size_t state_dim, int64_t *state_shape, void *action_old, size_t action_dim, int64_t *action_shape, const void *reward, bool final_state, torchfort_datatype_t dtype, cudaStream_t stream)¶
Adds a new \((s, a, s', r, d)\) tuple to the replay buffer.
Here \(s\) (state_old) is the state for which action \(a\) (action_old) was taken, leading to \(s'\) (state_new) and receiving reward \(r\) (reward). The terminal state flag \(d\) (final_state) specifies whether \(s'\) is the final state in the episode.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state_old – [in] A pointer to a memory buffer containing previous state data.
state_new – [in] A pointer to a memory buffer containing new state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state_old and state_new data.
action_old – [in] A pointer to a memory buffer containing action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
reward – [in] A pointer to a memory buffer with reward data.
final_state – [in] A flag indicating whether state_new is the final state in the current episode (set to true if it is the final state, otherwise false).
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation. This argument is ignored if the model is on the CPU.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
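A minimal sketch of one environment interaction, combining torchfort_rl_off_policy_predict_explore with the replay buffer update (the system name, shapes, and the environment stepping code are illustrative assumptions):

    #include <stdbool.h>
    #include <stdint.h>
    #include <torchfort.h>  /* assumed public header name */

    void interaction_step(float *state_old, float *state_new, float *action) {
      int64_t state_shape[2]  = {1, 16};  /* batch x state size */
      int64_t action_shape[2] = {1, 4};   /* batch x action size */
      float reward = 0.0f;                /* written by the user's environment */
      bool episode_done = false;

      /* 1. Ask the system for an explorative action for the current state. */
      torchfort_rl_off_policy_predict_explore("rl_system", state_old, 2,
                                              state_shape, action, 2,
                                              action_shape, TORCHFORT_FLOAT, 0);

      /* 2. Apply the action in the user's environment (not shown), filling
       *    state_new, reward, and episode_done. */

      /* 3. Store the (s, a, s', r, d) tuple in the replay buffer. */
      torchfort_rl_off_policy_update_replay_buffer(
          "rl_system", state_old, state_new, 2, state_shape, action, 2,
          action_shape, &reward, episode_done, TORCHFORT_FLOAT, 0);
    }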
torchfort_rl_off_policy_is_ready¶
-
torchfort_result_t torchfort_rl_off_policy_is_ready(const char *name, bool &ready)¶
Queries a reinforcement learning system for readiness to start training.
A user should call this method before starting training to make sure the reinforcement learning system is ready. This method ensures that the replay buffer is filled sufficiently with exploration data as specified during system creation.
- Parameters
name – [in] The name of the system instance to query, as defined during system creation.
ready – [out] A flag indicating whether the system is ready to train (true means it is ready to train).
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_off_policy_save_checkpoint¶
-
torchfort_result_t torchfort_rl_off_policy_save_checkpoint(const char *name, const char *checkpoint_dir)¶
Saves a reinforcement learning training checkpoint to a directory.
This method saves all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also saves the state of the replay buffer to allow for smooth restarts of reinforcement learning training processes. This function should be used in conjunction with torchfort_rl_off_policy_load_checkpoint.
- Parameters
name – [in] The name of a system instance to save, as defined during system creation.
checkpoint_dir – [in] A filesystem path to a directory to save the checkpoint data to.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_off_policy_load_checkpoint¶
-
torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char *name, const char *checkpoint_dir)¶
Restores a reinforcement learning system from a checkpoint.
This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also fully restores the state of the replay buffer, but not the current RNG seed. This function should be used in conjunction with torchfort_rl_off_policy_save_checkpoint.
- Parameters
name – [in] The name of a system instance to restore the data for, as defined during system creation.
checkpoint_dir – [in] A filesystem path to a directory which contains the checkpoint data to load.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
Weights and Biases Logging¶
The reinforcement learning system performs logging for all involved networks automatically during training. The following routines are provided for additional logging of system-relevant quantities, such as the accumulated reward.
torchfort_rl_off_policy_wandb_log_int¶
-
torchfort_result_t torchfort_rl_off_policy_wandb_log_int(const char *name, const char *metric_name, int64_t step, int value)¶
Writes an integer value to a Weights and Biases log using the system logging tag. Use the *_float and *_double variants to write float and double values, respectively.
- Parameters
name – [in] The name of system instance to associate this metric value with, as defined during system creation.
metric_name – [in] Metric label.
step – [in] Training/inference step to associate with metric value.
value – [in] Metric value to log.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_off_policy_wandb_log_float¶
-
torchfort_result_t torchfort_rl_off_policy_wandb_log_float(const char *name, const char *metric_name, int64_t step, float value)¶
torchfort_rl_off_policy_wandb_log_double¶
-
torchfort_result_t torchfort_rl_off_policy_wandb_log_double(const char *name, const char *metric_name, int64_t step, double value)¶
On-Policy Algorithms¶
System Creation¶
Basic routines to create and register a reinforcement learning system in the internal registry. A (synchronous) data parallel distributed option is available.
torchfort_rl_on_policy_create_system¶
-
torchfort_result_t torchfort_rl_on_policy_create_system(const char *name, const char *config_fname, int model_device, int rb_device)¶
Creates an on-policy reinforcement learning training system instance from a provided configuration file.
- Parameters
name – [in] A name to assign to the created training system instance to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined configuration file to use.
model_device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), the model will be placed on CPU. For values >= 0, the model will be placed on GPU with index corresponding to value.
rb_device – [in] Which device to place the rollout buffer on. For TORCHFORT_DEVICE_CPU (-1), the buffer will be placed on CPU. For values >= 0, the buffer will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_create_distributed_system¶
-
torchfort_result_t torchfort_rl_on_policy_create_distributed_system(const char *name, const char *config_fname, MPI_Comm mpi_comm, int model_device, int rb_device)¶
Creates a (synchronous) data-parallel on-policy reinforcement learning system instance from a provided configuration file.
- Parameters
name – [in] A name to assign to the created training system instance to use as a key for other TorchFort routines.
config_fname – [in] The filesystem path to the user-defined configuration file to use.
mpi_comm – [in] MPI communicator to use to initialize NCCL communication library for data-parallel communication.
model_device – [in] Which device to place and run the model on. For TORCHFORT_DEVICE_CPU (-1), the model will be placed on CPU. For values >= 0, the model will be placed on GPU with index corresponding to value.
rb_device – [in] Which device to place the rollout buffer on. For TORCHFORT_DEVICE_CPU (-1), the buffer will be placed on CPU. For values >= 0, the buffer will be placed on GPU with index corresponding to value.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
Training/Evaluation¶
These routines are used for training the reinforcement learning system or for steering the environment.
torchfort_rl_on_policy_train_step¶
-
torchfort_result_t torchfort_rl_on_policy_train_step(const char *name, float *p_loss_val, float *q_loss_val, cudaStream_t stream)¶
Runs a training iteration of an on-policy reinforcement learning instance and returns loss values for the policy and value functions.
This routine samples a batch of the specified size from the rollout buffer according to the buffer's sampling procedure and performs a train step using this sample. The details of the training procedure are abstracted away from the user and depend on the chosen system algorithm.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
p_loss_val – [out] A pointer to a memory location to write the policy loss value computed during the training iteration.
q_loss_val – [out] A pointer to a memory location to write the critic loss value computed during the training iteration. If the system uses multiple critics, the average across all critics is returned.
stream – [in] CUDA stream to enqueue the training operations.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_predict_explore¶
-
torchfort_result_t torchfort_rl_on_policy_predict_explore(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Suggests an action based on the current state of the system and adds noise as specified by the corresponding reinforcement learning system.
Depending on the reinforcement learning algorithm used, the prediction is performed by the main network (not the target network). In contrast to torchfort_rl_on_policy_predict, this routine adds noise and is therefore explorative. The kind of noise is specified during system creation.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [inout] A pointer to a memory buffer to write action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the action prediction operations.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_predict¶
-
torchfort_result_t torchfort_rl_on_policy_predict(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Suggests an action based on the current state of the system.
Depending on the algorithm used, the prediction is performed by the target network. In contrast to torchfort_rl_on_policy_predict_explore, this routine does not add noise, which means it is exploitative.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [inout] A pointer to a memory buffer to write action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the action prediction operations.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_evaluate¶
-
torchfort_result_t torchfort_rl_on_policy_evaluate(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, void *reward, size_t reward_dim, int64_t *reward_shape, torchfort_datatype_t dtype, cudaStream_t stream)¶
Predicts the future reward based on the current state and selected action.
Depending on the learning algorithm, the routine queries the target critic networks for this. The routine averages the predictions over all critics.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [in] A pointer to a memory buffer containing action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
reward – [inout] A pointer to a memory buffer to write reward data.
reward_dim – [in] Rank of the reward data.
reward_shape – [in] A pointer to an array specifying the shape of the reward data. Length should be equal to the rank of the reward data.
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the action prediction operations.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
System Management¶
The purpose of these routines is to manage the reinforcement learning system's internal data. They allow the user to add tuples to the rollout buffer and query the system for readiness. Additionally, save and restore functionality is provided.
torchfort_rl_on_policy_update_rollout_buffer¶
-
torchfort_result_t torchfort_rl_on_policy_update_rollout_buffer(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, const void *reward, bool final_state, torchfort_datatype_t dtype, cudaStream_t stream)¶
Adds a new \((s, a, r, d)\) tuple to the rollout buffer.
Here \(s\) (state) is the state for which action \(a\) (action) was taken, receiving reward \(r\) (reward). The terminal state flag \(d\) (final_state) specifies whether \(s\) is the final state of the episode. Note that value estimates \(q\) as well as log-probabilities are also stored; the user does not need to pass these manually, as they are computed internally from the current policy and stored with the other values.
- Parameters
name – [in] The name of system instance to use, as defined during system creation.
state – [in] A pointer to a memory buffer containing state data.
state_dim – [in] Rank of the state data.
state_shape – [in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.
action – [in] A pointer to a memory buffer containing action data.
action_dim – [in] Rank of the action data.
action_shape – [in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.
reward – [in] A pointer to a memory buffer with reward data.
final_state – [in] A flag indicating whether the state after state is the final state in the episode (set to true if this is the case, otherwise false).
dtype – [in] The TorchFort datatype to use for this operation.
stream – [in] CUDA stream to enqueue the operation.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
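A minimal sketch of collecting one rollout step for an on-policy system (the system name, shapes, and the environment stepping code are illustrative assumptions):

    #include <stdbool.h>
    #include <stdint.h>
    #include <torchfort.h>  /* assumed public header name */

    void collect_rollout_step(float *state, float *action) {
      int64_t state_shape[2]  = {1, 16};  /* batch x state size */
      int64_t action_shape[2] = {1, 4};   /* batch x action size */
      float reward = 0.0f;                /* written by the user's environment */
      bool episode_done = false;

      /* Sample an action from the current policy (with exploration noise). */
      torchfort_rl_on_policy_predict_explore("ppo_system", state, 2,
                                             state_shape, action, 2,
                                             action_shape, TORCHFORT_FLOAT, 0);

      /* ... apply `action` in the environment, obtaining reward and
       *     episode_done ... */

      /* Record (s, a, r, d); value estimates and log-probabilities are
       * computed and stored internally by the system. */
      torchfort_rl_on_policy_update_rollout_buffer(
          "ppo_system", state, 2, state_shape, action, 2, action_shape,
          &reward, episode_done, TORCHFORT_FLOAT, 0);
    }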
torchfort_rl_on_policy_is_ready¶
-
torchfort_result_t torchfort_rl_on_policy_is_ready(const char *name, bool &ready)¶
Queries a reinforcement learning system for readiness to start training.
A user should call this method before starting training to make sure the reinforcement learning system is ready. This method ensures that the rollout buffer is filled sufficiently with exploration data as specified during system creation. It also checks whether the rollout buffer was properly finalized, e.g. that all advantages were computed.
- Parameters
name – [in] The name of the system instance to query, as defined during system creation.
ready – [out] A flag indicating whether the system is ready to train (true means it is ready to train).
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_save_checkpoint¶
-
torchfort_result_t torchfort_rl_on_policy_save_checkpoint(const char *name, const char *checkpoint_dir)¶
Saves a reinforcement learning training checkpoint to a directory.
This method saves all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also saves the state of the rollout buffer to allow for smooth restarts of reinforcement learning training processes. This function should be used in conjunction with torchfort_rl_on_policy_load_checkpoint.
- Parameters
name – [in] The name of a system instance to save, as defined during system creation.
checkpoint_dir – [in] A filesystem path to a directory to save the checkpoint data to.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_load_checkpoint¶
-
torchfort_result_t torchfort_rl_on_policy_load_checkpoint(const char *name, const char *checkpoint_dir)¶
Restores a reinforcement learning system from a checkpoint.
This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also fully restores the state of the rollout buffer, but not the current RNG seed. This function should be used in conjunction with torchfort_rl_on_policy_save_checkpoint.
- Parameters
name – [in] The name of a system instance to restore the data for, as defined during system creation.
checkpoint_dir – [in] A filesystem path to a directory which contains the checkpoint data to load.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
Weights and Biases Logging¶
The reinforcement learning system performs logging for all involved networks automatically during training. The following routines are provided for additional logging of system-relevant quantities, such as the accumulated reward.
torchfort_rl_on_policy_wandb_log_int¶
-
torchfort_result_t torchfort_rl_on_policy_wandb_log_int(const char *name, const char *metric_name, int64_t step, int value)¶
Writes an integer value to a Weights and Biases log using the system logging tag. Use the *_float and *_double variants to write float and double values, respectively.
- Parameters
name – [in] The name of system instance to associate this metric value with, as defined during system creation.
metric_name – [in] Metric label.
step – [in] Training/inference step to associate with metric value.
value – [in] Metric value to log.
- Returns
TORCHFORT_RESULT_SUCCESS
on success or error code on failure.
torchfort_rl_on_policy_wandb_log_float¶
-
torchfort_result_t torchfort_rl_on_policy_wandb_log_float(const char *name, const char *metric_name, int64_t step, float value)¶
torchfort_rl_on_policy_wandb_log_double¶
-
torchfort_result_t torchfort_rl_on_policy_wandb_log_double(const char *name, const char *metric_name, int64_t step, double value)¶