File size: 4,205 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""tilelli.utils.runtime — small training-loop niceties.

Two helpers, both designed to be no-ops on hardware that doesn't need them:

  - ``ThermalGuard``: poll Linux thermal zones and sleep when the chip
    gets too hot. Lets you train overnight on a fanless laptop without
    cooking the silicon. On a host with no readable thermal zones this
    becomes a free no-op.

  - ``polite_training``: yield the CPU to the rest of the system once
    per training step. Keeps the machine usable while a long training
    run is in flight. On CUDA hosts it does almost nothing (just a gc);
    on CPU hosts it also attempts a malloc trim and sleeps for about a
    millisecond.

Neither helper is required to run Tilelli — they exist so that small
hosts (cheap laptop, single-board computer, $200 mini-PC) can run a
training session without becoming unusable for everything else.
"""
from __future__ import annotations

import functools
import gc
import time
from dataclasses import dataclass
from pathlib import Path

import torch


def _read_thermal_max_celsius() -> float | None:
    """Return the hottest readable thermal zone in °C, or None if no
    /sys/class/thermal/* zones are present (most non-Linux hosts)."""
    try:
        zones = sorted(Path("/sys/class/thermal").glob("thermal_zone*/temp"))
    except OSError:
        return None
    if not zones:
        return None
    temps: list[float] = []
    for z in zones:
        try:
            temps.append(int(z.read_text().strip()) / 1000.0)
        except (OSError, ValueError):
            continue
    return max(temps) if temps else None


@dataclass
class ThermalGuard:
    """Polls the hottest thermal zone and sleeps when it crosses a cap.

    Usage in a training loop::

        guard = ThermalGuard(high_c=80.0, resume_c=72.0)
        for step in range(steps):
            guard.maybe_throttle(step)
            train_step(...)

    Parameters
    ----------
    high_c : float
        Start throttling at or above this temperature.
    resume_c : float
        Stop throttling only once temperature falls back below this.
        Must be lower than ``high_c`` to avoid threshold sawtooth.
    cool_down_s : float
        How long to sleep per throttle cycle before re-reading.
        Must not be negative.
    check_every : int
        Poll every N training steps (avoid reading /sys every step;
        thermal changes are slow relative to a training step).
        Must be at least 1.

    Raises
    ------
    ValueError
        If ``resume_c >= high_c``, ``check_every < 1`` or
        ``cool_down_s < 0``.
    """

    high_c: float = 85.0
    resume_c: float = 75.0
    cool_down_s: float = 2.0
    check_every: int = 20
    # Internal bookkeeping; not meant to be set by callers.
    _throttling: bool = False  # True while a cool-down cycle is in progress
    _total_throttle_s: float = 0.0  # sum of the requested sleep durations
    _throttle_events: int = 0  # how many times throttling was entered
    _last_temp_c: float | None = None  # latest reading (None if unreadable)
    _available: bool | None = None  # cached probe result; None = not probed

    def __post_init__(self) -> None:
        """Reject configurations that would misbehave inside the loop."""
        if self.resume_c >= self.high_c:
            raise ValueError(
                f"resume_c ({self.resume_c}) must be < high_c ({self.high_c})"
            )
        # Fail here rather than with a ZeroDivisionError in the middle of
        # a training run (``step % check_every``).
        if self.check_every < 1:
            raise ValueError(
                f"check_every ({self.check_every}) must be >= 1"
            )
        # time.sleep() rejects negative durations; surface that up front.
        if self.cool_down_s < 0:
            raise ValueError(
                f"cool_down_s ({self.cool_down_s}) must be >= 0"
            )

    @property
    def available(self) -> bool:
        """Whether a thermal reading could be obtained on this host.

        The probe runs once, on first access, and the result is cached.
        """
        if self._available is None:
            self._available = _read_thermal_max_celsius() is not None
        return self._available

    def maybe_throttle(self, step: int) -> None:
        """Block while the host is too hot; otherwise return immediately.

        Does nothing when no thermal reading is available, or when
        ``step`` is not a multiple of ``check_every`` (unless a previous
        cool-down is still flagged as in progress). Once the temperature
        is at or above ``high_c`` it sleeps in ``cool_down_s`` slices
        until a reading drops below ``resume_c`` or becomes unreadable.

        Parameters
        ----------
        step : int
            Current training step, used to decide whether to poll.
        """
        if not self.available:
            return
        if step % self.check_every != 0 and not self._throttling:
            return
        t = _read_thermal_max_celsius()
        if t is None:
            return
        self._last_temp_c = t
        if not self._throttling and t >= self.high_c:
            self._throttling = True
            self._throttle_events += 1
        while self._throttling:
            time.sleep(self.cool_down_s)
            self._total_throttle_s += self.cool_down_s
            t2 = _read_thermal_max_celsius()
            # Record every re-read so the last temperature does not go
            # stale while the host is still cooling down.
            self._last_temp_c = t2
            if t2 is None or t2 < self.resume_c:
                self._throttling = False


def polite_training() -> None:
    """Yield the CPU briefly and trim allocators. Cheap nicety so a long
    CPU run doesn't make the machine unusable for everything else."""
    gc.collect()
    if not torch.cuda.is_available():
        try:
            import ctypes

            libc = ctypes.CDLL("libc.so.6")
            libc.malloc_trim(0)
        except Exception:
            pass
        time.sleep(0.001)