validation_utils
Utility functions for validating models, extracting hidden states, and computing similarity metrics.
TODO: Consider moving this to a separate module dedicated to scoring.
Functions
- Parameters:
  - args (DictConfig)
  - model (nn.Module | StitchedModule)
  - tokenizer (PreTrainedTokenizerBase)
  - output_dir (str | Path)
  - model_name (str)
  - extra_payload (dict[str, Any] | None)
- Return type:
  - list[Tensor | LowMemorySparseTensor]
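The return type above (a list of `Tensor` or `LowMemorySparseTensor` entries) suggests hidden states are collected batch by batch. The following is a minimal sketch of that collection pattern, assuming one entry per validation batch in batch order. `collect_hidden_states_per_batch`, `toy_teacher`, and the plain-float-list representation are hypothetical stand-ins, not this module's API.

```python
from typing import Callable, Iterable, List, Sequence

# Hypothetical stand-in for one batch's hidden states: a flat list of floats.
# The documented function returns Tensor or LowMemorySparseTensor objects instead.
HiddenState = List[float]


def collect_hidden_states_per_batch(
    forward: Callable[[Sequence[float]], HiddenState],
    batches: Iterable[Sequence[float]],
) -> List[HiddenState]:
    """Run `forward` on each batch and keep one hidden-state entry per batch.

    The i-th returned entry belongs to the i-th batch, so a later comparison
    step can pair entries by position.
    """
    hidden_states_per_batch: List[HiddenState] = []
    for batch in batches:
        hidden_states_per_batch.append(forward(batch))
    return hidden_states_per_batch


def toy_teacher(batch: Sequence[float]) -> HiddenState:
    """Toy model used only for the demo: doubles every input value."""
    return [2.0 * x for x in batch]


if __name__ == "__main__":
    batches = [[1.0, 2.0], [3.0, 4.0]]
    print(collect_hidden_states_per_batch(toy_teacher, batches))
    # -> [[2.0, 4.0], [6.0, 8.0]]
```

Keeping the output as an ordered list, rather than a dict keyed by batch, is what lets a list-typed parameter such as `target_hidden_states_per_batch` be consumed positionally later.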
- validate_model_with_teacher_similarity_metrics(args, model, tokenizer, target_hidden_states_per_batch, output_dir, model_name, extra_payload=None, calculate_full_score_ablations=False, val_dataloader=None)
- Parameters:
  - args (DictConfig)
  - model (nn.Module | StitchedModule)
  - tokenizer (PreTrainedTokenizerBase)
  - target_hidden_states_per_batch (list[Tensor])
  - output_dir (str | Path)
  - model_name (str)
  - extra_payload (dict[str, Any] | None)
  - calculate_full_score_ablations (bool)
- Return type:
  - None
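`validate_model_with_teacher_similarity_metrics` compares a model against precomputed teacher outputs passed as `target_hidden_states_per_batch`. Which metrics it actually computes is not documented here, so the sketch below assumes two common ones, cosine similarity and mean squared error, computed per batch and averaged. `teacher_similarity_metrics` and its helpers are hypothetical and operate on plain float lists instead of tensors.

```python
import math
from typing import Dict, List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def mean_squared_error(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean of squared element-wise differences of two equal-length, non-empty vectors."""
    if len(a) != len(b) or not a:
        raise ValueError("vectors must be non-empty and of the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def teacher_similarity_metrics(
    student_states_per_batch: List[Sequence[float]],
    target_hidden_states_per_batch: List[Sequence[float]],
) -> Dict[str, float]:
    """Average per-batch cosine similarity and MSE between student and teacher states.

    Batches are paired by position, so both lists must have one entry per batch.
    """
    if len(student_states_per_batch) != len(target_hidden_states_per_batch):
        raise ValueError("expected one teacher entry per student batch")
    if not student_states_per_batch:
        raise ValueError("need at least one batch")
    pairs = list(zip(student_states_per_batch, target_hidden_states_per_batch))
    cosines = [cosine_similarity(s, t) for s, t in pairs]
    errors = [mean_squared_error(s, t) for s, t in pairs]
    return {
        "cosine_similarity": sum(cosines) / len(pairs),
        "mse": sum(errors) / len(pairs),
    }


if __name__ == "__main__":
    student = [[1.0, 0.0], [0.0, 2.0]]
    teacher = [[1.0, 0.0], [0.0, 1.0]]
    print(teacher_similarity_metrics(student, teacher))
    # -> {'cosine_similarity': 1.0, 'mse': 0.25}
```

Cosine similarity ignores scale while MSE does not: in the demo the second student batch is twice the teacher's, yet its cosine similarity is still 1.0, and only the MSE exposes the gap. Reporting both separates directional agreement from magnitude drift.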
- write_results(output_dir, result_name, args, payload)
- Parameters:
  - output_dir (str | Path)
  - result_name (str)
  - args (DictConfig)
  - payload (dict[str, Any])
- Return type:
  - None
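`write_results` persists a payload under a result name. The on-disk format is not documented here; this sketch assumes one JSON file per result at `<output_dir>/<result_name>.json`, holding the run arguments and the payload. `write_results_sketch` is hypothetical: it takes a plain `dict` in place of `DictConfig` (with OmegaConf, `OmegaConf.to_container(args, resolve=True)` produces such a dict), and it returns the written path for convenience, whereas the documented function returns `None`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def write_results_sketch(
    output_dir: Union[str, Path],
    result_name: str,
    args: Dict[str, Any],
    payload: Dict[str, Any],
) -> Path:
    """Write `args` and `payload` to `<output_dir>/<result_name>.json` (assumed layout)."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)  # create the directory if needed
    result_path = output_dir / f"{result_name}.json"
    # Keep args and payload under separate keys so payload fields cannot shadow args.
    record = {"args": args, "payload": payload}
    # default=str turns non-JSON values (e.g. Path objects) into strings instead of failing.
    result_path.write_text(json.dumps(record, indent=2, default=str))
    return result_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = write_results_sketch(tmp, "validation", {"seed": 42}, {"cosine_similarity": 0.98})
        print(path.name)  # -> validation.json
        print(json.loads(path.read_text()))
```

Storing `args` next to the payload keeps each result file self-describing, so a score can be traced back to the configuration that produced it without consulting separate logs.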