postprocess
Utils to load and process model_config.
Functions
| Function | Description |
|---|---|
| check_weight_shape_valid | Check if weight shapes are valid with the target inference TP. |
| pad_embedding_lm_head | Pad lm_head and embedding weights to multiples of 64 for AWQ quantization. |
| postprocess_model_config | Postprocess the model configs from the training tensor parallelism to the target inference tensor parallelism. |
| postprocess_tensors | Ensure all tensors in the model_config are on CPU, contiguous, and own their memory. |
| update_lm_head_quantization | Update the lm_head quantization config for TRT-LLM export. |
- check_weight_shape_valid(config, inference_tensor_parallel=1, training_tensor_parallel=1)
Check if weight shapes are valid with the target inference TP.
This function is recursive.
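A minimal call sketch; the import path is inferred from this page's module name, and `config` is assumed to be a ModelConfig produced by the export step:
```python
# Hypothetical usage sketch -- the import path is an assumption based on
# this page's module name and may differ in your install.
from modelopt.torch.export.postprocess import check_weight_shape_valid

# Fails if any sharded weight cannot be split evenly across the target
# inference tensor parallelism (config is a prebuilt ModelConfig).
check_weight_shape_valid(
    config,
    inference_tensor_parallel=2,
    training_tensor_parallel=1,
)
```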
- pad_embedding_lm_head(model_config, padding_factor=64)
Pad lm_head and embedding weights to multiples of 64 for AWQ quantization (a padding sketch follows this entry).
- Parameters:
model_config (ModelConfig) –
padding_factor (int) –
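The padding itself is standard tensor manipulation. An illustrative, self-contained sketch of rounding a vocab dimension up to a multiple of `padding_factor` (not the library's implementation):
```python
import torch
import torch.nn.functional as F

def pad_vocab_to_multiple(weight: torch.Tensor, padding_factor: int = 64) -> torch.Tensor:
    # Round the vocab (row) dimension up to the next multiple of
    # padding_factor and zero-pad; AWQ kernels expect dimensions
    # divisible by the quantization group size.
    vocab_size = weight.shape[0]
    padded = -(-vocab_size // padding_factor) * padding_factor  # ceil division
    # F.pad pads dims from last to first: (left, right, top, bottom).
    return F.pad(weight, (0, 0, 0, padded - vocab_size))

w = torch.randn(50257, 4096)           # vocab x hidden, GPT-2-style vocab
print(pad_vocab_to_multiple(w).shape)  # torch.Size([50304, 4096])
```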
- postprocess_model_config(model_config, inference_tensor_parallel=1, inference_pipeline_parallel=1, training_pipeline_parallel=1, workspace_path=None)
Postprocess the model configs from the training tensor parallelism to the target inference tensor parallelism (see the usage sketch after this entry).
If training_pipeline_parallel > 1, the model configs across PP ranks are merged into one.
- Returns:
- The processed model configs as a list.
- For the merging case:
The merged rank returns the merged model_config as a single-item list. The other ranks return an empty list, as they are no longer exported.
- For the split case:
The split model config list is returned.
- Parameters:
model_config (ModelConfig) –
inference_tensor_parallel (int) –
inference_pipeline_parallel (int) –
training_pipeline_parallel (int) –
workspace_path (Path | str | None) –
- Return type:
List[ModelConfig]
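A hedged usage sketch; the surrounding variable names are assumptions, not from this page:
```python
# Hypothetical usage after a per-rank export. With
# training_pipeline_parallel == 1 there is no merge step, so the call
# returns one config per target inference TP rank.
configs = postprocess_model_config(
    model_config,
    inference_tensor_parallel=2,
    inference_pipeline_parallel=1,
    training_pipeline_parallel=1,
)
for rank, cfg in enumerate(configs):  # len(configs) == 2 here
    print(f"inference TP rank {rank}: {type(cfg).__name__}")
```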
- postprocess_tensors(model_config, force_cpu=True, force_contiguous=True, force_non_view=True)
Ensure all tensors in the model_config are on CPU, contiguous, and own their memory (see the sketch after this entry).
- Parameters:
model_config (ModelConfig) –
force_cpu (bool) –
force_contiguous (bool) –
force_non_view (bool) –
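An illustrative sketch of what the three flags imply for a single tensor; this is not the library's implementation:
```python
import torch

def normalize_tensor(t: torch.Tensor) -> torch.Tensor:
    # force_cpu: move the tensor off the accelerator.
    t = t.detach().cpu()
    # force_contiguous: materialize a dense, row-major layout.
    t = t.contiguous()
    # force_non_view: clone if the tensor aliases a larger buffer, so it
    # owns exactly its own storage (important before serialization).
    if t.untyped_storage().nbytes() != t.numel() * t.element_size():
        t = t.clone()
    return t
```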
- update_lm_head_quantization(config, lm_head, inference_tensor_parallel=1)
Update the lm_head quantization config for TRT-LLM export (see the call sketch after this entry).
- Parameters:
config (ModelConfig) –
lm_head (QuantLinear) –
inference_tensor_parallel (int) –
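A minimal call sketch; the variable names are assumptions:
```python
# Hypothetical usage before TRT-LLM export: config is the ModelConfig being
# exported and lm_head is the model's QuantLinear output projection.
update_lm_head_quantization(
    config,
    lm_head,
    inference_tensor_parallel=2,
)
```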