validation_utils

Utility functions for validating models, extracting hidden states, and computing similarity metrics.

TODO: Consider moving this to a separate module dedicated to scoring.

Functions

validate_model_and_extract_hidden_states

validate_model_with_teacher_similarity_metrics

write_results

validate_model_and_extract_hidden_states(args, model, tokenizer, output_dir, model_name, extra_payload=None, val_dataloader=None)
Parameters:
  • args (DictConfig)

  • model (nn.Module | StitchedModule)

  • tokenizer (PreTrainedTokenizerBase)

  • output_dir (str | Path)

  • model_name (str)

  • extra_payload (dict[str, Any] | None)

  • val_dataloader (optional; defaults to None)

Return type:

list[Tensor | LowMemorySparseTensor]
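The signature suggests a per-batch loop: the model is run over a validation dataloader and one hidden-state object is kept per batch, which matches the list return type. The sketch below is a hypothetical, dependency-free illustration of that pattern only. `extract_hidden_states_per_batch` and `toy_model` are invented names, and plain Python lists stand in for torch Tensors. Details such as evaluation mode, `torch.no_grad()`, device placement, tokenization, or sparse storage via `LowMemorySparseTensor` are not shown.

```python
from typing import Callable, Iterable, List, Sequence

# Hypothetical stand-in types: a "batch" is a list of token ids and a
# "hidden state" is a list of floats. The documented function works on
# torch Tensors produced by an nn.Module or StitchedModule.
Batch = Sequence[int]
Hidden = List[float]


def extract_hidden_states_per_batch(
    model: Callable[[Batch], Hidden],
    dataloader: Iterable[Batch],
) -> List[Hidden]:
    """Hypothetical: run `model` on every batch and keep one output per batch, in order."""
    hidden_states_per_batch: List[Hidden] = []
    for batch in dataloader:
        hidden_states_per_batch.append(model(batch))
    return hidden_states_per_batch


def toy_model(batch: Batch) -> Hidden:
    # Toy "model": scales each token id; stands in for a real forward pass.
    return [0.5 * token for token in batch]


batches = [[1, 2, 3], [4, 5, 6]]
states = extract_hidden_states_per_batch(toy_model, batches)
print(states)  # [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0]]
```

Keeping the outputs ordered by batch is what allows a later step to pair them one-to-one with another model's outputs on the same batches.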

validate_model_with_teacher_similarity_metrics(args, model, tokenizer, target_hidden_states_per_batch, output_dir, model_name, extra_payload=None, calculate_full_score_ablations=False, val_dataloader=None)
Parameters:
  • args (DictConfig)

  • model (nn.Module | StitchedModule)

  • tokenizer (PreTrainedTokenizerBase)

  • target_hidden_states_per_batch (list[Tensor])

  • output_dir (str | Path)

  • model_name (str)

  • extra_payload (dict[str, Any] | None)

  • calculate_full_score_ablations (bool)

  • val_dataloader (optional; defaults to None)

Return type:

None
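The parameter `target_hidden_states_per_batch` has the same list-per-batch shape as the return value of `validate_model_and_extract_hidden_states`. This suggests, though the documentation does not state it, a two-step workflow: extract hidden states from a teacher model, then score another model against them. The sketch below assumes that workflow. The metrics actually computed are not documented here, so cosine similarity and mean squared error are used as illustrative stand-ins. All names are hypothetical, and plain lists replace Tensors.

```python
import math
from typing import Dict, List, Sequence

Hidden = Sequence[float]


def cosine_similarity(a: Hidden, b: Hidden) -> float:
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def mean_squared_error(a: Hidden, b: Hidden) -> float:
    """Average squared element-wise difference."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def teacher_similarity_metrics(
    student_per_batch: List[Hidden],
    target_per_batch: List[Hidden],
) -> Dict[str, float]:
    """Hypothetical: compare student and teacher hidden states batch by batch."""
    if len(student_per_batch) != len(target_per_batch):
        raise ValueError("student and teacher must cover the same batches")
    pairs = list(zip(student_per_batch, target_per_batch))
    cosines = [cosine_similarity(s, t) for s, t in pairs]
    errors = [mean_squared_error(s, t) for s, t in pairs]
    # Average each per-batch metric over all batches.
    return {
        "cosine_similarity": sum(cosines) / len(cosines),
        "mse": sum(errors) / len(errors),
    }


teacher = [[1.0, 0.0], [0.0, 2.0]]  # e.g. the output of an extraction step
student = [[1.0, 0.0], [0.0, 1.0]]
metrics = teacher_similarity_metrics(student, teacher)
print(metrics)  # {'cosine_similarity': 1.0, 'mse': 0.25}
```

Cosine similarity ignores scale, so the second batch (where the student has half the teacher's magnitude) still scores 1.0, while the MSE term detects the mismatch. Reporting more than one metric is one way to catch this kind of difference.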

write_results(output_dir, result_name, args, payload)
Parameters:
  • output_dir (str | Path)

  • result_name (str)

  • args (DictConfig)

  • payload (dict[str, Any])

Return type:

None
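A minimal sketch of what a results writer with this signature might do, assuming JSON output named after `result_name`. The real file format and layout are not documented here. `write_results_sketch` is a hypothetical name. It takes a plain dict for `args`; a `DictConfig` could be converted first with `OmegaConf.to_container(args, resolve=True)`. Unlike the documented function, it returns the written path so the example is easy to check.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def write_results_sketch(
    output_dir: Union[str, Path],
    result_name: str,
    args: Dict[str, Any],
    payload: Dict[str, Any],
) -> Path:
    """Hypothetical: write args and payload together as <output_dir>/<result_name>.json."""
    output_dir = Path(output_dir)
    # Create the directory tree if it does not exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)
    # Store the run configuration next to the results so a run can be traced back.
    record = {"args": args, "payload": payload}
    path = output_dir / f"{result_name}.json"
    path.write_text(json.dumps(record, indent=2))
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = write_results_sketch(tmp, "student_vs_teacher", {"seed": 42}, {"mse": 0.25})
    loaded = json.loads(path.read_text())
    file_name = path.name

print(file_name)  # student_vs_teacher.json
print(loaded)  # {'args': {'seed': 42}, 'payload': {'mse': 0.25}}
```

Nesting the payload under its own key, rather than merging it with `args`, avoids silently overwriting a configuration value that shares a name with a result field.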