TorchFort C API

These are all the types and functions available in the TorchFort C API.

General

Types

torchfort_datatype_t

enum torchfort_datatype_t

This enum defines the supported data types.

Values:

enumerator TORCHFORT_FLOAT
enumerator TORCHFORT_DOUBLE
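Because buffers are passed as void pointers alongside a torchfort_datatype_t, the caller must size them for the chosen datatype. The sketch below uses a hypothetical helper, dtype_size_bytes, and redeclares the enum only so that it compiles on its own. The enumerator names come from this reference, but their order and numeric values here are assumptions, so real code should include torchfort.h instead.

```c
#include <stddef.h>

/* Stand-in for the declaration in torchfort.h (names from this reference,
 * order assumed). Include torchfort.h instead in real code. */
typedef enum { TORCHFORT_FLOAT, TORCHFORT_DOUBLE } torchfort_datatype_t;

/* Hypothetical helper: size in bytes of one element of the given
 * TorchFort datatype, or 0 for an unrecognized value. */
static size_t dtype_size_bytes(torchfort_datatype_t dtype) {
  switch (dtype) {
    case TORCHFORT_FLOAT:  return sizeof(float);
    case TORCHFORT_DOUBLE: return sizeof(double);
    default:               return 0;
  }
}
```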

torchfort_result_t

enum torchfort_result_t

This enum defines the possible return values from TorchFort. Most functions in the TorchFort library return one of these values to indicate whether an operation completed successfully or an error occurred.

Values:

enumerator TORCHFORT_RESULT_SUCCESS

The operation completed successfully.

enumerator TORCHFORT_RESULT_INVALID_USAGE

A user error, typically an invalid argument.

enumerator TORCHFORT_RESULT_NOT_SUPPORTED

A user error, requesting an invalid or unsupported operation configuration.

enumerator TORCHFORT_RESULT_INTERNAL_ERROR

An internal library error that should be reported.

enumerator TORCHFORT_RESULT_CUDA_ERROR

An error occurred in the CUDA Runtime.

enumerator TORCHFORT_RESULT_MPI_ERROR

An error occurred in the MPI library.

enumerator TORCHFORT_RESULT_NCCL_ERROR

An error occurred in the NCCL library.
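Since most routines return a torchfort_result_t, a common pattern is to wrap each call in a check that reports failures. The following is a minimal sketch with a hypothetical result_name helper and CHECK_TORCHFORT macro, neither of which is part of the TorchFort API. The enum is redeclared only so the sketch compiles on its own; its order and numeric values are assumptions, so compare results against the enumerator names and include torchfort.h in real code.

```c
#include <stdio.h>

/* Stand-in for the declaration in torchfort.h (names from this reference,
 * order and values assumed). Include torchfort.h instead in real code. */
typedef enum {
  TORCHFORT_RESULT_SUCCESS,
  TORCHFORT_RESULT_INVALID_USAGE,
  TORCHFORT_RESULT_NOT_SUPPORTED,
  TORCHFORT_RESULT_INTERNAL_ERROR,
  TORCHFORT_RESULT_CUDA_ERROR,
  TORCHFORT_RESULT_MPI_ERROR,
  TORCHFORT_RESULT_NCCL_ERROR
} torchfort_result_t;

/* Hypothetical helper: returns the enumerator name for a result code. */
static const char *result_name(torchfort_result_t res) {
  switch (res) {
    case TORCHFORT_RESULT_SUCCESS:        return "TORCHFORT_RESULT_SUCCESS";
    case TORCHFORT_RESULT_INVALID_USAGE:  return "TORCHFORT_RESULT_INVALID_USAGE";
    case TORCHFORT_RESULT_NOT_SUPPORTED:  return "TORCHFORT_RESULT_NOT_SUPPORTED";
    case TORCHFORT_RESULT_INTERNAL_ERROR: return "TORCHFORT_RESULT_INTERNAL_ERROR";
    case TORCHFORT_RESULT_CUDA_ERROR:     return "TORCHFORT_RESULT_CUDA_ERROR";
    case TORCHFORT_RESULT_MPI_ERROR:      return "TORCHFORT_RESULT_MPI_ERROR";
    case TORCHFORT_RESULT_NCCL_ERROR:     return "TORCHFORT_RESULT_NCCL_ERROR";
    default:                              return "unknown torchfort_result_t";
  }
}

/* Hypothetical macro: evaluates a TorchFort call once and, if it did not
 * succeed, prints the call site and result name to stderr. */
#define CHECK_TORCHFORT(call)                                        \
  do {                                                               \
    torchfort_result_t check_res_ = (call);                          \
    if (check_res_ != TORCHFORT_RESULT_SUCCESS) {                    \
      fprintf(stderr, "%s:%d: %s returned %s\n", __FILE__, __LINE__, \
              #call, result_name(check_res_));                       \
    }                                                                \
  } while (0)
```

In an application this might be used as CHECK_TORCHFORT(torchfort_create_model("mymodel", "config.yaml")), where the model name and file name are placeholders. Whether to abort or continue after a failure is left to the caller.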


Global Context Settings

These are global routines which affect the behavior of the libtorch backend. It is therefore recommended to call these functions before any other TorchFort calls are made.

torchfort_set_cudnn_benchmark

torchfort_result_t torchfort_set_cudnn_benchmark(const bool flag)

Utility function to enable/disable cuDNN runtime autotuning in PyTorch.

Parameters

flag[in] Boolean value to set the cuDNN benchmarking flag to.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.

Supervised Learning

Model Creation

torchfort_create_model

torchfort_result_t torchfort_create_model(const char *name, const char *config_fname)

Creates a model instance from a provided configuration file.

Parameters
  • name[in] A name to assign to the created model instance to use as a key for other TorchFort routines.

  • config_fname[in] The filesystem path to the user-defined model configuration file to use.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_create_distributed_model

torchfort_result_t torchfort_create_distributed_model(const char *name, const char *config_fname, MPI_Comm mpi_comm)

Creates a distributed data-parallel model from a provided configuration file.

Parameters
  • name[in] A name to assign to the created model instance to use as a key for other TorchFort routines.

  • config_fname[in] The filesystem path to the user-defined model configuration file to use.

  • mpi_comm[in] MPI communicator to use to initialize NCCL communication library for data-parallel communication.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


Model Training/Inference

torchfort_train

torchfort_result_t torchfort_train(const char *name, void *input, size_t input_dim, int64_t *input_shape, void *label, size_t label_dim, int64_t *label_shape, void *loss_val, torchfort_datatype_t dtype, cudaStream_t stream)

Runs a training iteration of a model instance using provided input and label data.

Parameters
  • name[in] The name of model instance to use, as defined during model creation.

  • input[in] A pointer to a memory buffer containing input data.

  • input_dim[in] Rank of the input data.

  • input_shape[in] A pointer to an array specifying the shape of the input data. Length should be equal to the rank of the input data.

  • label[in] A pointer to a memory buffer containing label data.

  • label_dim[in] Rank of the label data.

  • label_shape[in] A pointer to an array specifying the shape of the label data. Length should be equal to the rank of the label data.

  • loss_val[out] A pointer to a memory location to write the loss value computed during the training iteration.

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the training operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.
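Each buffer is described by a rank (input_dim, label_dim) and a shape array of that length (input_shape, label_shape). The sketch below uses a hypothetical helper, shape_numel, to compute the element count such a pair describes. Multiplying that count by the element size of dtype gives the required buffer size in bytes. The example shape {8, 1, 32, 32} (for instance eight single-channel 32x32 fields) is only an illustration. How each dimension is interpreted is determined by the model, not by this routine.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: number of elements described by a rank/shape pair
 * such as (input_dim, input_shape). Returns -1 if any extent is negative.
 * Overflow of very large shapes is not checked. */
static int64_t shape_numel(size_t dim, const int64_t *shape) {
  int64_t n = 1;
  for (size_t i = 0; i < dim; ++i) {
    if (shape[i] < 0) return -1;
    n *= shape[i];
  }
  return n;
}
```

For TORCHFORT_FLOAT data with input_dim = 4 and input_shape = {8, 1, 32, 32}, the input buffer would need shape_numel(4, input_shape) * sizeof(float) bytes.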


torchfort_inference

torchfort_result_t torchfort_inference(const char *name, void *input, size_t input_dim, int64_t *input_shape, void *output, size_t output_dim, int64_t *output_shape, torchfort_datatype_t dtype, cudaStream_t stream)

Runs inference on a model using provided input data.

Parameters
  • name[in] The name of model instance to use, as defined during model creation.

  • input[in] A pointer to a memory buffer containing input data.

  • input_dim[in] Rank of the input data.

  • input_shape[in] A pointer to an array specifying the shape of the input data. Length should be equal to the rank of the input data.

  • output[inout] A pointer to a memory buffer to write output data.

  • output_dim[in] Rank of the output data.

  • output_shape[in] A pointer to an array specifying the shape of the output data. Length should be equal to the rank of the output data.

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the inference operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


Model Management

torchfort_save_model

torchfort_result_t torchfort_save_model(const char *name, const char *fname)

Saves a model to file.

Parameters
  • name[in] The name of model instance to save, as defined during model creation.

  • fname[in] The filename to save the model weights to.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_load_model

torchfort_result_t torchfort_load_model(const char *name, const char *fname)

Loads a model from a file.

Parameters
  • name[in] The name of model instance to load the model weights to, as defined during model creation.

  • fname[in] The filename to load the model from.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_save_checkpoint

torchfort_result_t torchfort_save_checkpoint(const char *name, const char *checkpoint_dir)

Saves a training checkpoint to a directory. In contrast to torchfort_save_model, this function saves additional state to restart training, like the optimizer states and learning rate schedule progress.

Parameters
  • name[in] The name of model instance to save, as defined during model creation.

  • checkpoint_dir[in] A writeable filesystem path to a directory to save the checkpoint data to.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_load_checkpoint

torchfort_result_t torchfort_load_checkpoint(const char *name, const char *checkpoint_dir, int64_t *step_train, int64_t *step_inference)

Loads a training checkpoint from a directory. In contrast to torchfort_load_model, this function loads additional state to restart training, like the optimizer states and learning rate schedule progress.

Parameters
  • name[in] The name of model instance to load checkpoint data into, as defined during model creation.

  • checkpoint_dir[in] A readable filesystem path to a directory to load the checkpoint data from.

  • step_train[out] A pointer to an integer to which the training step of the loaded checkpoint is written.

  • step_inference[out] A pointer to an integer to which the inference step of the loaded checkpoint is written.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


Weights and Biases Logging

torchfort_wandb_log_int

torchfort_result_t torchfort_wandb_log_int(const char *name, const char *metric_name, int64_t step, int value)

Writes an integer value to a Weights and Biases log. Use the _float and _double variants to write float and double values, respectively.

Parameters
  • name[in] The name of model instance to associate this metric value with, as defined during model creation.

  • metric_name[in] Metric label.

  • step[in] Training/inference step to associate with metric value.

  • value[in] Metric value to log.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_wandb_log_float

torchfort_result_t torchfort_wandb_log_float(const char *name, const char *metric_name, int64_t step, float value)

torchfort_wandb_log_double

torchfort_result_t torchfort_wandb_log_double(const char *name, const char *metric_name, int64_t step, double value)

Reinforcement Learning

System Creation

Basic routines to create and register a reinforcement learning system in the internal registry. A (synchronous) data parallel distributed option is available.

torchfort_rl_off_policy_create_system

torchfort_result_t torchfort_rl_off_policy_create_system(const char *name, const char *config_fname)

Creates an off-policy reinforcement learning training system instance from a provided configuration file.

Parameters
  • name[in] A name to assign to the created training system instance to use as a key for other TorchFort routines.

  • config_fname[in] The filesystem path to the user-defined configuration file to use.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_create_distributed_system

torchfort_result_t torchfort_rl_off_policy_create_distributed_system(const char *name, const char *config_fname, MPI_Comm mpi_comm)

Creates a (synchronous) data-parallel off-policy reinforcement learning system instance from a provided configuration file.

Parameters
  • name[in] A name to assign to the created training system instance to use as a key for other TorchFort routines.

  • config_fname[in] The filesystem path to the user-defined configuration file to use.

  • mpi_comm[in] MPI communicator to use to initialize NCCL communication library for data-parallel communication.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


Training/Evaluation

These routines are used for training the reinforcement learning system or for steering the environment.

torchfort_rl_off_policy_train_step

torchfort_result_t torchfort_rl_off_policy_train_step(const char *name, float *p_loss_val, float *q_loss_val, cudaStream_t stream)

Runs a training iteration of an off-policy reinforcement learning instance and returns loss values for the policy and value functions.

This routine samples a batch of the specified size from the replay buffer according to the buffer's sampling procedure and performs a training step on this sample. The details of the training procedure are abstracted away from the user and depend on the chosen algorithm.

Parameters
  • name[in] The name of system instance to use, as defined during system creation.

  • p_loss_val[out] A pointer to a memory location to write the policy loss value computed during the training iteration.

  • q_loss_val[out] A pointer to a memory location to write the critic loss value computed during the training iteration. If the system uses multiple critics, the average across all critics is returned.

  • stream[in] CUDA stream to enqueue the training operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_predict_explore

torchfort_result_t torchfort_rl_off_policy_predict_explore(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)

Suggests an action based on the current state of the system and adds noise as specified by the corresponding reinforcement learning system.

Depending on the reinforcement learning algorithm used, the prediction is performed by the main network (not the target network). In contrast to torchfort_rl_off_policy_predict, this routine adds noise and is therefore explorative. The kind of noise is specified during system creation.

Parameters
  • name[in] The name of system instance to use, as defined during system creation.

  • state[in] A pointer to a memory buffer containing state data.

  • state_dim[in] Rank of the state data.

  • state_shape[in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.

  • action[inout] A pointer to a memory buffer to write action data.

  • action_dim[in] Rank of the action data.

  • action_shape[in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the action prediction operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_predict

torchfort_result_t torchfort_rl_off_policy_predict(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, torchfort_datatype_t dtype, cudaStream_t stream)

Suggests an action based on the current state of the system.

Depending on the algorithm used, the prediction is performed by the target network. In contrast to torchfort_rl_off_policy_predict_explore, this routine does not add noise, which means it is exploitative.

Parameters
  • name[in] The name of system instance to use, as defined during system creation.

  • state[in] A pointer to a memory buffer containing state data.

  • state_dim[in] Rank of the state data.

  • state_shape[in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.

  • action[inout] A pointer to a memory buffer to write action data.

  • action_dim[in] Rank of the action data.

  • action_shape[in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the action prediction operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_evaluate

torchfort_result_t torchfort_rl_off_policy_evaluate(const char *name, void *state, size_t state_dim, int64_t *state_shape, void *action, size_t action_dim, int64_t *action_shape, void *reward, size_t reward_dim, int64_t *reward_shape, torchfort_datatype_t dtype, cudaStream_t stream)

Predicts the future reward based on the current state and selected action.

Depending on the learning algorithm, the routine queries the target critic networks for this. The routine averages the predictions over all critics.
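Writing \(K\) for the number of critics and \(Q'_k\) for the \(k\)-th target critic (notation introduced here for illustration), the averaged prediction written to reward can be expressed as:

\[ \bar{Q}(s, a) = \frac{1}{K} \sum_{k=1}^{K} Q'_k(s, a) \]

For an algorithm with a single critic, \(K = 1\) and the result is simply that critic's prediction.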

Parameters
  • name[in] The name of system instance to use, as defined during system creation.

  • state[in] A pointer to a memory buffer containing state data.

  • state_dim[in] Rank of the state data.

  • state_shape[in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state data.

  • action[in] A pointer to a memory buffer containing action data.

  • action_dim[in] Rank of the action data.

  • action_shape[in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.

  • reward[inout] A pointer to a memory buffer to write reward data.

  • reward_dim[in] Rank of the reward data.

  • reward_shape[in] A pointer to an array specifying the shape of the reward data. Length should be equal to the rank of the reward data.

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the reward prediction operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


System Management

The purpose of these routines is to manage the reinforcement learning system's internal data. They allow the user to add tuples to the replay buffer and to query the system for readiness. Save and restore functionality is also provided.

torchfort_rl_off_policy_update_replay_buffer

torchfort_result_t torchfort_rl_off_policy_update_replay_buffer(const char *name, void *state_old, void *state_new, size_t state_dim, int64_t *state_shape, void *action_old, size_t action_dim, int64_t *action_shape, const void *reward, bool final_state, torchfort_datatype_t dtype, cudaStream_t stream)

Adds a new \((s, a, s', r, d)\) tuple to the replay buffer.

Here \(s\) (state_old) is the state for which action \(a\) (action_old) was taken, leading to \(s'\) (state_new) and receiving reward \(r\) (reward). The terminal state flag \(d\) (final_state) specifies whether \(s'\) is the final state in the episode.
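As general background, and not as a description of TorchFort's internals: in DDPG-style off-policy algorithms the terminal flag \(d\) removes the bootstrapped term from the critic's regression target, so a terminal transition is valued by its immediate reward alone. With discount factor \(\gamma\), target critic \(Q'\) and target policy \(\pi'\) (symbols introduced here), the target reads:

\[ y = r + \gamma \, (1 - d) \, Q'\big(s', \pi'(s')\big) \]

If final_state is left false at a true episode end, the target bootstraps from \(s'\) even though no further reward follows in that episode.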

Parameters
  • name[in] The name of system instance to use, as defined during system creation.

  • state_old[in] A pointer to a memory buffer containing previous state data.

  • state_new[in] A pointer to a memory buffer containing new state data.

  • state_dim[in] Rank of the state data.

  • state_shape[in] A pointer to an array specifying the shape of the state data. Length should be equal to the rank of the state_old and state_new data.

  • action_old[in] A pointer to a memory buffer containing action data.

  • action_dim[in] Rank of the action data.

  • action_shape[in] A pointer to an array specifying the shape of the action data. Length should be equal to the rank of the action data.

  • reward[in] A pointer to a memory buffer with reward data.

  • final_state[in] A flag indicating whether state_new is the final state in the current episode (true if it is the final state, otherwise false).

  • dtype[in] The TorchFort datatype to use for this operation.

  • stream[in] CUDA stream to enqueue the replay buffer update operations.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_is_ready

torchfort_result_t torchfort_rl_off_policy_is_ready(const char *name, bool &ready)

Queries a reinforcement learning system for readiness to start training.

A user should call this method before starting training to make sure the reinforcement learning system is ready. It reports whether the replay buffer has been filled sufficiently with exploration data, as specified during system creation.

Parameters
  • name[in] The name of the system instance to query, as defined during system creation.

  • ready[out] A flag indicating whether the system is ready to train (true means it is ready to train)

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_save_checkpoint

torchfort_result_t torchfort_rl_off_policy_save_checkpoint(const char *name, const char *checkpoint_dir)

Saves a reinforcement learning training checkpoint to a directory.

This method saves all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also saves the state of the replay buffer, to allow for smooth restarts of reinforcement learning training processes. This function should be used in conjunction with torchfort_rl_off_policy_load_checkpoint.

Parameters
  • name[in] The name of a system instance to save, as defined during system creation.

  • checkpoint_dir[in] A filesystem path to a directory to save the checkpoint data to.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_load_checkpoint

torchfort_result_t torchfort_rl_off_policy_load_checkpoint(const char *name, const char *checkpoint_dir)

Restores a reinforcement learning system from a checkpoint.

This method restores all models (policies, critics, target models if available) together with their corresponding optimizer and LR scheduler states. It also fully restores the state of the replay buffer, but not the current RNG seed. This function should be used in conjunction with torchfort_rl_off_policy_save_checkpoint.

Parameters
  • name[in] The name of a system instance to restore the data for, as defined during system creation.

  • checkpoint_dir[in] A filesystem path to a directory which contains the checkpoint data to load.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


Weights and Biases Logging

The reinforcement learning system performs logging for all involved networks automatically during training. The following routines are provided for additional logging of system-relevant quantities, such as the accumulated reward.
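The accumulated reward is not tracked by these routines, so the caller has to compute it. The following sketch uses a hypothetical return_tracker type and tracker_add function. It sums the per-step rewards (the same values passed to torchfort_rl_off_policy_update_replay_buffer) and reports the total when final_state ends an episode. It assumes a single environment and undiscounted returns.

```c
#include <stdbool.h>

/* Hypothetical accumulator for the undiscounted reward of one episode. */
typedef struct {
  float episode_return; /* reward summed over the current episode */
  int episode;          /* number of completed episodes */
} return_tracker;

/* Adds one step's reward. Returns true when final_state ends the episode;
 * *completed then holds that episode's total and the running sum is reset. */
static bool tracker_add(return_tracker *t, float reward, bool final_state,
                        float *completed) {
  t->episode_return += reward;
  if (!final_state) return false;
  *completed = t->episode_return;
  t->episode_return = 0.0f;
  t->episode += 1;
  return true;
}
```

When tracker_add returns true, the total could be logged with torchfort_rl_off_policy_wandb_log_float(name, "episode_return", t.episode, completed). The metric label "episode_return" is an arbitrary example.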

torchfort_rl_off_policy_wandb_log_int

torchfort_result_t torchfort_rl_off_policy_wandb_log_int(const char *name, const char *metric_name, int64_t step, int value)

Writes an integer value to a Weights and Biases log using the system logging tag. Use the _float and _double variants to write float and double values, respectively.

Parameters
  • name[in] The name of system instance to associate this metric value with, as defined during system creation.

  • metric_name[in] Metric label.

  • step[in] Training/inference step to associate with metric value.

  • value[in] Metric value to log.

Returns

TORCHFORT_RESULT_SUCCESS on success or error code on failure.


torchfort_rl_off_policy_wandb_log_float

torchfort_result_t torchfort_rl_off_policy_wandb_log_float(const char *name, const char *metric_name, int64_t step, float value)

torchfort_rl_off_policy_wandb_log_double

torchfort_result_t torchfort_rl_off_policy_wandb_log_double(const char *name, const char *metric_name, int64_t step, double value)