Infer amplify
`get_converted_hf_checkpoint(hf_model_name, results_path)`
Convert a HuggingFace model to a NeMo checkpoint and return the path.
Source code in bionemo/amplify/infer_amplify.py
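For example, a minimal usage sketch (assuming the function is importable from `bionemo.amplify.infer_amplify`; the HuggingFace model name below is illustrative):

```python
from pathlib import Path

from bionemo.amplify.infer_amplify import get_converted_hf_checkpoint

# Illustrative HuggingFace model name; substitute the AMPLIFY checkpoint you want to convert.
hf_model_name = "chandar-lab/AMPLIFY_120M"

# Converts the HF weights into a NeMo checkpoint under results_path and returns its path.
nemo_ckpt_path = get_converted_hf_checkpoint(
    hf_model_name=hf_model_name,
    results_path=Path("results"),
)
print(nemo_ckpt_path)
```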
`main(data_path, hf_model_name=None, initial_ckpt_path=None, results_path=Path('results'), seq_length=1024, include_hiddens=False, include_embeddings=False, include_logits=False, include_input_ids=False, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, prediction_interval='epoch')`
Runs inference on a BioNeMo AMPLIFY model using PyTorch Lightning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_path` | `Path` | Path to the input data CSV file | *required* |
| `hf_model_name` | `str \| None` | HuggingFace model name/path to load | `None` |
| `initial_ckpt_path` | `str \| None` | Path to the initial checkpoint to load. Only one of `hf_model_name` or `initial_ckpt_path` should be provided. | `None` |
| `results_path` | `Path` | Path to save inference results | `Path('results')` |
| `seq_length` | `int` | Min/max sequence length for padding | `1024` |
| `include_hiddens` | `bool` | Whether to include hidden states in output | `False` |
| `include_embeddings` | `bool` | Whether to include embeddings in output | `False` |
| `include_logits` | `bool` | Whether to include token logits in output | `False` |
| `include_input_ids` | `bool` | Whether to include input IDs in output | `False` |
| `micro_batch_size` | `int` | Micro batch size for inference | `64` |
| `precision` | `str` | Precision type for inference | `'bf16-mixed'` |
| `tensor_model_parallel_size` | `int` | Tensor model parallel size | `1` |
| `pipeline_model_parallel_size` | `int` | Pipeline model parallel size | `1` |
| `devices` | `int` | Number of devices to use | `1` |
| `num_nodes` | `int` | Number of nodes for distributed inference | `1` |
| `prediction_interval` | `str` | Interval at which to write predictions to disk | `'epoch'` |
Source code in bionemo/amplify/infer_amplify.py
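For example, a minimal inference sketch (assuming `main` is importable from `bionemo.amplify.infer_amplify`; the CSV path and HuggingFace model name are placeholders):

```python
from pathlib import Path

from bionemo.amplify.infer_amplify import main

# Placeholder input CSV containing the sequences to run inference on.
data_path = Path("sequences.csv")

# Provide either hf_model_name or initial_ckpt_path, not both.
main(
    data_path=data_path,
    hf_model_name="chandar-lab/AMPLIFY_120M",  # illustrative model name
    results_path=Path("results"),
    seq_length=1024,
    include_embeddings=True,   # write per-sequence embeddings
    include_logits=True,       # write token logits
    micro_batch_size=64,
    precision="bf16-mixed",
    prediction_interval="epoch",
)
```

Predictions are written under `results_path` at the chosen `prediction_interval`.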