Skip to content

Tensorboard

`verify_tensorboard_logs(tb_log_dir, expected_metrics, min_steps=1)`

Verify that TensorBoard logs exist and contain expected metrics.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tb_log_dir` | `Path` | Path to the TensorBoard log directory | *required* |
| `expected_metrics` | `list[str]` | List of metric names expected in the logs | *required* |
| `min_steps` | `int` | Minimum number of steps expected in the logs | `1` |

Returns:

| Type | Description |
|------|-------------|
| `Optional[str]` | `None` if verification succeeds; an error-message string if it fails |

Source code in `bionemo/testing/tensorboard.py` (lines 22–60):
def verify_tensorboard_logs(tb_log_dir: Path, expected_metrics: list[str], min_steps: int = 1) -> Optional[str]:
    """Verify that TensorBoard logs exist and contain expected metrics.

    Only event files located directly in ``tb_log_dir`` are considered; sub-directories
    are not searched. A metric counts as present when its name is a substring of at least
    one logged scalar tag, so ``"loss"`` matches ``"train/loss"`` and also ``"val_loss"``.

    The step requirement is checked against the scalar series that covers the most
    distinct steps. It does not use whichever tag happens to be first, because the first
    tag may be a one-off value (e.g. a hyper-parameter metric written once at step 0).

    Args:
        tb_log_dir: Path to the TensorBoard log directory
        expected_metrics: List of metric names expected in the logs
        min_steps: Minimum number of distinct steps that at least one logged scalar
            series must cover. Values <= 0 disable the step check.

    Returns:
        None if verification succeeds, error message string if it fails
    """
    # Find event files in the log directory (non-recursive, like EventAccumulator on a directory).
    event_files = list(tb_log_dir.glob("events.out.tfevents.*"))
    if not event_files:
        return f"No TensorBoard event files found in {tb_log_dir}"

    # Keep every scalar event. The default size guidance reservoir-samples scalars down to
    # 10,000 per tag, which would undercount long runs; 0 means "keep all".
    event_acc = EventAccumulator(str(tb_log_dir), size_guidance={"scalars": 0})
    event_acc.Reload()

    # Get available scalar tags
    scalar_tags = event_acc.Tags()["scalars"]

    # Check that expected metrics are present (may carry prefixes like "train/" or suffixes)
    for metric in expected_metrics:
        if not any(metric in tag for tag in scalar_tags):
            return f"Expected metric '{metric}' not found in TensorBoard logs. Available tags: {scalar_tags}"

    # Verify we have logged data for at least min_steps.
    if min_steps > 0:
        # Count distinct steps per tag so repeated writes at the same step are counted once,
        # then take the best-covered series. An empty tag list yields 0, so logs that contain
        # no scalars at all fail the check instead of passing silently.
        steps_logged = max(
            (len({event.step for event in event_acc.Scalars(tag)}) for tag in scalar_tags),
            default=0,
        )
        if steps_logged < min_steps:
            return f"Expected at least {min_steps} steps logged, but found {steps_logged}"

    return None