# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
import logging
import os
import queue
import re
import sys
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional


class NodeLocalTmpLogHandler(logging.Handler):
    """Log handler that appends formatted records to size-rotated temporary files on node-local storage."""

def __init__(
self,
rank_id: Optional[int],
file_path: str,
max_file_size: int,
max_backup_files: int,
proc_name: str,
):
super().__init__()
self.fname = None
self.flock = threading.Lock()
self.rank_id = rank_id
self.file_path = file_path
self.max_file_size = max_file_size
self.max_backup_files = max_backup_files
self.proc_name = proc_name
def emit(self, record: logging.LogRecord):
"""Emit a log record."""
try:
# Format the message using the formatter (which handles rank info dynamically)
msg = self.format(record)
self._write_message(message=msg)
        except (OSError, RuntimeError):  # IOError is an alias of OSError on Python 3
            # Fall back to stderr so the record is not silently lost
            sys.stderr.write(f"Log handler error: {record.getMessage()}\n")
            sys.stderr.flush()
def _get_backup_files(self):
"""Return sorted list of backup files for this rank/process."""
rank_str = str(self.rank_id) if self.rank_id is not None else "unknown"
file_prefix = f"rank_{rank_str}_{self.proc_name}.msg."
backup_files = [
filename
for filename in os.listdir(self.file_path)
if re.match(rf"{file_prefix}(\d+)", filename)
]
        # Sort numerically by the millisecond-timestamp suffix, oldest first
        backup_files.sort(key=lambda name: int(name[len(file_prefix):]))
return backup_files
    def _log_file_namer(self):
        """Return the file name to write to: resume the newest existing file, else start a new one."""
        backup_files = self._get_backup_files()
        if self.fname is None and backup_files:
            # Resume the most recent existing file; if it is already over the
            # size limit, the rotation check in _write_message starts a new one
            return backup_files[-1]
        rank_str = str(self.rank_id) if self.rank_id is not None else "unknown"
        file_prefix = f"rank_{rank_str}_{self.proc_name}.msg."
        # Millisecond-timestamp suffix so file names sort chronologically
        return f"{file_prefix}{int(time.time() * 1000)}"
def _cleanup_old_backup_files(self):
"""Clean up old log files, keeping only the most recent ones."""
backup_files = self._get_backup_files()
        # Delete all but the newest max_backup_files files
        for old_file in backup_files[: -self.max_backup_files]:
            try:
                os.remove(os.path.join(self.file_path, old_file))
            except OSError as e:
sys.stderr.write(f"Failed to remove backup file {old_file}: {e}\n")
sys.stderr.flush()
    def _write_message(self, message: str):
        """Append the message to this rank's current file, rotating by size as needed."""
with self.flock:
if self.fname is None:
os.makedirs(self.file_path, exist_ok=True)
self.fname = os.path.join(self.file_path, self._log_file_namer())
# Check if file needs rotation
if os.path.exists(self.fname):
try:
file_size = os.path.getsize(self.fname)
if file_size > self.max_file_size:
self.fname = os.path.join(self.file_path, self._log_file_namer())
self._cleanup_old_backup_files()
            except OSError as e:
sys.stderr.write(f"File rotation error for {self.fname}: {e}\n")
sys.stderr.flush()
# Append message to the rank's message file
with open(self.fname, 'a') as f:
f.write(f"{message}\n")
f.flush() # Ensure message is written immediately
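
# A minimal wiring sketch (illustrative only: the directory, sizes, and logger
# name below are assumptions, not part of this module's API):
#
#   handler = NodeLocalTmpLogHandler(
#       rank_id=0,
#       file_path="/tmp/node_logs",   # node-local scratch dir shared by all ranks
#       max_file_size=1 << 20,        # rotate once a file passes ~1 MiB
#       max_backup_files=3,
#       proc_name="trainer",
#   )
#   handler.setFormatter(DynamicLogFormatter(workload_rank=0, workload_local_rank=0,
#                                            infra_rank=0, infra_local_rank=0,
#                                            fmt=FMT))  # FMT as sketched after
#                                                       # DynamicLogFormatter below
#   logging.getLogger("workload").addHandler(handler)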


class DynamicLogFormatter(logging.Formatter):
    """Formatter that stamps each record with the workload/infra rank values supplied at construction."""

def __init__(
self,
workload_rank=None,
workload_local_rank=None,
infra_rank=None,
infra_local_rank=None,
fmt=None,
datefmt=None,
):
super().__init__(fmt, datefmt)
self.workload_rank = workload_rank
self.workload_local_rank = workload_local_rank
self.infra_rank = infra_rank
self.infra_local_rank = infra_local_rank
def format(self, record):
# Fallback to "?" for None values
record.workload_rank = self.workload_rank if self.workload_rank is not None else "?"
record.workload_local_rank = (
self.workload_local_rank if self.workload_local_rank is not None else "?"
)
record.infra_rank = self.infra_rank if self.infra_rank is not None else "?"
record.infra_local_rank = (
self.infra_local_rank if self.infra_local_rank is not None else "?"
)
# Use the parent's format method
return super().format(record)
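
# For the node aggregator to re-parse these records, the handler's format
# string must line up with LogMessage.log_pattern below. A plausible fmt
# (an assumption inferred from that pattern, not taken from this module):
#
#   FMT = ("%(asctime)s [%(levelname)s] [%(hostname)s] "
#          "[workload:%(workload_rank)s(%(workload_local_rank)s) "
#          "infra:%(infra_rank)s(%(infra_local_rank)s)] "
#          "%(filename)s:%(lineno)d %(message)s")
#
# logging.LogRecord has no hostname field, so %(hostname)s would have to be
# injected by a logging.Filter or baked into the format string. The default
# datefmt already produces the "YYYY-MM-DD HH:MM:SS,mmm" asctime the parser expects.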


class LogMessage:
    """A single log line parsed into named fields, exposing a Unix timestamp for chronological merging."""

log_pattern = re.compile(
r"(?P<asctime>[\d-]+\s[\d:,]+) \[(?P<levelname>\w+)\] \[(?P<hostname>[\w.-]+)\] "
r"\[workload:(?P<workload_rank>\d+)\((?P<workload_local_rank>\d+)\) infra:(?P<infra_rank>\d+)\((?P<infra_local_rank>\d+)\)\] "
r"(?P<filename>[\w.]+):(?P<lineno>\d+) (?P<message>.+)"
)
def __init__(self, log_message: str):
self.log_message = log_message
self.hash_table = {}
match = LogMessage.log_pattern.match(log_message)
if match:
log_fields = match.groupdict()
for key, value in log_fields.items():
                if key == 'asctime':
                    # Convert asctime to a Unix timestamp; keep the float so
                    # sub-second precision survives for chronological merging
                    dt = datetime.strptime(value, '%Y-%m-%d %H:%M:%S,%f')
                    self.hash_table[key] = dt.timestamp()
else:
self.hash_table[key] = value
        if 'asctime' not in self.hash_table:
            # Line did not match log_pattern: fall back to the arrival time
            self.hash_table['asctime'] = datetime.now().timestamp()
    def getts(self):
        """Return the message's Unix timestamp (used as the merge key)."""
        return self.hash_table['asctime']
def __str__(self):
return self.log_message
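
# A line of the shape this pattern expects (hostname and ranks are made-up):
#
#   2025-01-01 12:00:00,123 [INFO] [node-0] [workload:3(1) infra:3(1)] train.py:42 step complete
#
# parses into asctime (converted to a Unix timestamp), levelname, hostname,
# the four rank fields, filename, lineno, and message.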


class NodeLogAggregator:
    """Merge per-rank message files from node-local temp storage into one node-level log file.

    When en_chrono_ord is true, messages from different ranks are merged in
    timestamp order; otherwise they are written out in arrival order.
    """

def __init__(
self, log_dir: str, temp_dir: str, log_file: str, max_file_size: int, en_chrono_ord: bool
):
self._log_dict_queue = {}
self._aggregator_thread = None
self._stop_event = threading.Event()
self._max_msg_file_size = max_file_size
        # All ranks on the same node must be given the same temp_dir so their
        # message files land in one directory for the aggregator to find
        self._temp_dir = temp_dir
os.makedirs(self._temp_dir, exist_ok=True)
# Create log directory if it doesn't exist
self._log_dir = log_dir
os.makedirs(self._log_dir, exist_ok=True)
self._log_file = log_file
self.en_chrono_ord = en_chrono_ord
# Track file positions for each rank to avoid re-reading
self._file_positions = {}

    def shutdown(self):
        """Stop the aggregator thread and wait for it to finish draining pending messages."""
self._stop_event.set()
if self._aggregator_thread:
self._aggregator_thread.join()
self._aggregator_thread = None

    def start_aggregator(self):
"""Start the log aggregator thread."""
if self._aggregator_thread is not None:
return
self._aggregator_thread = threading.Thread(
target=self._aggregator_loop, daemon=True, name="LogAggregator"
)
self._aggregator_thread.start()
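
    # Lifecycle sketch (the paths and sizes here are illustrative assumptions):
    #
    #   agg = NodeLogAggregator(
    #       log_dir="/var/log/node",    # where the merged node.log is written
    #       temp_dir="/tmp/node_logs",  # same dir the rank handlers write into
    #       log_file="node.log",
    #       max_file_size=1 << 20,      # typically the handlers' rotation size
    #       en_chrono_ord=True,         # merge ranks' messages by timestamp
    #   )
    #   agg.start_aggregator()
    #   ...                             # workload runs; ranks emit log records
    #   agg.shutdown()                  # drain remaining messages, then stop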
def _write_messages_to_file(self, messages: List[LogMessage], output):
# Write messages to output
for msg in messages:
try:
# The message is already formatted by the formatter, just write it
output.write(msg.log_message + '\n')
output.flush()
except Exception as e:
# Fallback to stderr if output fails
sys.stderr.write(f"Log output error: {e}\n")
sys.stderr.flush()
    def _merge_sort_streaming_lists(
        self, msg_dict: Dict[str, queue.SimpleQueue], heap: List
    ) -> list:
        """K-way merge of per-source message queues by timestamp; heap state persists across calls."""
        if not self.en_chrono_ord:
            # Chronological ordering disabled: drain every queue in arrival order
            unsorted_msgs = []
            for key, msg_q in msg_dict.items():
                while not msg_q.empty():
                    unsorted_msgs.append(msg_q.get())
            msg_dict.clear()
            return unsorted_msgs
        # Seed the heap with the head message of each source not already
        # represented by an entry carried over from the previous call
        heap_keys = {key for _, key, _ in heap}
        for key, msg_q in msg_dict.items():
            if key not in heap_keys and not msg_q.empty():
                lm = msg_q.get()
                # Heap entries are (timestamp, source key, message); the key
                # breaks ties between equal timestamps
                heapq.heappush(heap, (lm.getts(), key, lm))
        sorted_msgs = []
        while heap:
            ts, key, log_entry = heapq.heappop(heap)
            sorted_msgs.append(log_entry)
            msg_q = msg_dict[key]
            if not msg_q.empty():
                next_log = msg_q.get()
                heapq.heappush(heap, (next_log.getts(), key, next_log))
            else:
                # This source is (momentarily) dry: stop merging rather than
                # emit messages that an unread, earlier message might precede;
                # the remaining heap entries are resumed on the next call
                break
        return sorted_msgs
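
    # Worked example: with sources A=[t1, t9] and B=[t2] (t1 < t2 < t9), the
    # merge emits t1, refills from A, emits t2, then stops because B is dry;
    # t9 stays in the heap, so a later B message with t3 < t9 is still emitted
    # in order on the next call.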
    def _process_messages(self, output):
        """Poll for new messages and write them to output until shutdown has drained everything."""
        # After shutdown is requested, keep polling for up to 50 idle
        # iterations (~1.25 s at the 25 ms sleep below) to drain in-flight writes
        keep_processing = 50
        msg_dict = {}
        heap = []
while keep_processing:
if self._stop_event.is_set():
# Gives room for aggregator to catch up with writes
if len(heap) == 0:
keep_processing -= 1
# Check for pending messages from other ranks
self._check_pending_messages()
# Process queued messages
for key, lm_q in self._log_dict_queue.items():
if key in msg_dict:
curr_q = msg_dict[key]
while not lm_q.empty():
curr_q.put(lm_q.get())
else:
msg_dict[key] = lm_q
self._log_dict_queue.clear()
sorted_msgs = self._merge_sort_streaming_lists(msg_dict, heap)
if len(sorted_msgs) > 0:
self._write_messages_to_file(sorted_msgs, output)
# Sleep briefly to avoid busy waiting
time.sleep(0.025)
def _aggregator_loop(self):
"""Main loop for the log aggregator."""
        # Open the per-node log file and process messages until shutdown
        log_file = os.path.join(self._log_dir, self._log_file)
        with open(log_file, 'a', buffering=1) as output:  # line buffered
            self._process_messages(output)
    def _check_pending_messages(self):
        """Scan the temp directory for per-rank message files and ingest any new lines."""
if not os.path.exists(self._temp_dir):
return
# Check if we can access the directory
if not os.access(self._temp_dir, os.R_OK):
return
# Look for message files from all ranks (including this aggregator rank)
for filename in os.listdir(self._temp_dir):
if not filename.startswith('rank_'):
continue
msg_file = os.path.join(self._temp_dir, filename)
# Process current file
self._process_message_file(msg_file)
def _process_message_file(self, msg_file: str):
"""Process a single message file (current or backup)."""
try:
file_size = os.path.getsize(msg_file)
except FileNotFoundError as e:
# File was deleted/renamed between discovery and processing
# This can happen due to race conditions, but should be logged for debugging
sys.stderr.write(f"File not found during processing {msg_file}: {e}\n")
sys.stderr.flush()
return
        except OSError as e:
# Unexpected: Permission issues, disk problems, etc.
# Log this as it might indicate a real problem
sys.stderr.write(f"Unexpected error accessing {msg_file}: {e}\n")
sys.stderr.flush()
return
# Get the last known position for this file
last_position = self._file_positions.get(msg_file, 0)
        # Nothing new to read since the last poll
        if file_size <= last_position:
            if file_size >= self._max_msg_file_size:
                # Fully read and at rotation size: the writer is done with
                # this file, so see whether it can now be deleted
                self._cleanup_old_backup_files(os.path.basename(msg_file))
            return
# Read new content from the file
try:
with open(msg_file, 'r') as f:
f.seek(last_position)
lines = f.readlines()
file_size = f.tell()
except FileNotFoundError as e:
# File was deleted between size check and read
sys.stderr.write(f"File not found during read {msg_file}: {e}\n")
sys.stderr.flush()
return
        except OSError as e:
# File is being written by another process or other I/O error
# Log this as it might indicate a real problem
sys.stderr.write(f"IO error reading {msg_file}: {e}\n")
sys.stderr.flush()
return
# Process each line
log_msg_q = queue.SimpleQueue()
for line in lines:
line = line.strip()
if not line:
continue
log_msg = LogMessage(line)
log_msg_q.put(log_msg)
self._log_dict_queue[msg_file] = log_msg_q
# Update the position for this file
self._file_positions[msg_file] = file_size
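
    # Incremental tailing: _file_positions records how far each file has been
    # read, so each poll parses only the new bytes. For example, a file read to
    # offset 1024 that grows to 1500 yields only the lines in [1024, 1500).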
    def _cleanup_old_backup_files(self, msg_file: str):
        """Delete a fully-read message file once a newer file exists for the same rank/process."""
        # Expected name shape: <prefix>.msg.<millisecond timestamp>
        parts_first = msg_file.split('.', 1)
        parts_last = msg_file.rsplit('.', 1)
        if len(parts_first) < 2 or len(parts_last) < 2:
            sys.stderr.write(f"Skipping '{msg_file}': missing '.' parts\n")
            sys.stderr.flush()
            return
        to_del_ts = parts_last[-1]
        if not to_del_ts.isdigit():
            sys.stderr.write(f"Skipping '{msg_file}': last part is not numeric\n")
            sys.stderr.flush()
            return
        to_del_prefix = parts_first[0]
        if not os.path.exists(self._temp_dir):
            return
        for filename in os.listdir(self._temp_dir):
            match = re.match(rf"{re.escape(to_del_prefix)}\.msg\.(\d+)", filename)
            if not match:
                continue
            cur_file_ts = match.group(1)
            if int(cur_file_ts) > int(to_del_ts):
                # A newer file exists, so the writer has rotated past msg_file
                # and it is safe to delete
                try:
                    os.remove(os.path.join(self._temp_dir, msg_file))
                    break
                except OSError as e:
                    # Log the error but don't fail the entire operation
                    sys.stderr.write(f"Failed to remove backup file {msg_file}: {e}\n")
                    sys.stderr.flush()