# Source code for zoo.orca.learn.pytorch.utils

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2017 The Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is adapted from
# https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/utils.py

import collections
from contextlib import closing, contextmanager
import logging
import numpy as np
import socket
import time


# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Dictionary keys for aggregated metric/stat results. BATCH_COUNT and
# NUM_SAMPLES are the keys written by AverageMeterCollection.summary.
BATCH_COUNT = "batch_count"
NUM_SAMPLES = "num_samples"
# NOTE(review): the leading "*" is part of the key's value; presumably it
# marks the key as internal — confirm with callers before changing it.
BATCH_SIZE = "*batch_size"


class TimerStat:
    """A running stat for conveniently logging the duration of a code block.

    Only the most recent ``window_size`` durations (and units-processed
    counts) are kept for the windowed statistics below; ``count`` is the
    number of durations pushed since construction or the last ``reset``.

    Note that this class is *not* thread-safe.

    Examples:
        Time a call to 'time.sleep'.

        >>> import time
        >>> sleep_timer = TimerStat()
        >>> with sleep_timer:
        ...     time.sleep(1)
        >>> round(sleep_timer.mean)
        1
    """

    def __init__(self, window_size=10):
        """Create an empty timer keeping at most ``window_size`` samples."""
        self._window_size = window_size
        self._samples = []
        self._units_processed = []
        self._start_time = None
        # Running total over every pushed duration (not just the window);
        # nothing in this class reads it back.
        self._total_time = 0.0
        self.count = 0

    def __enter__(self):
        """Start timing; overlapping use of the same timer is rejected."""
        assert self._start_time is None, "concurrent updates not supported"
        self._start_time = time.time()
        # Returning self enables ``with timer as t:``; plain ``with timer:``
        # usage is unaffected.
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Stop timing and record the elapsed ``time.time()`` seconds.

        The duration is recorded even when the block raised; returning None
        lets any exception propagate.
        """
        assert self._start_time is not None
        time_delta = time.time() - self._start_time
        self.push(time_delta)
        self._start_time = None

    def push(self, time_delta):
        """Record one duration, evicting the oldest once the window is full."""
        self._samples.append(time_delta)
        if len(self._samples) > self._window_size:
            self._samples.pop(0)
        self.count += 1
        self._total_time += time_delta

    def push_units_processed(self, n):
        """Record how many units one timed step processed (windowed)."""
        self._units_processed.append(n)
        if len(self._units_processed) > self._window_size:
            self._units_processed.pop(0)

    @property
    def mean(self):
        """Mean of the windowed durations (NaN plus a numpy warning if empty)."""
        return np.mean(self._samples)

    @property
    def median(self):
        """Median of the windowed durations (NaN plus a numpy warning if empty)."""
        return np.median(self._samples)

    @property
    def sum(self):
        """Sum of the windowed durations (0.0 if empty)."""
        return np.sum(self._samples)

    @property
    def max(self):
        """Largest windowed duration; numpy raises ValueError if empty."""
        return np.max(self._samples)

    @property
    def first(self):
        """Oldest duration still in the window, or None if empty."""
        return self._samples[0] if self._samples else None

    @property
    def last(self):
        """Most recent duration, or None if empty."""
        return self._samples[-1] if self._samples else None

    @property
    def size(self):
        """Number of durations currently in the window."""
        return len(self._samples)

    @property
    def mean_units_processed(self):
        """Mean of the windowed units-processed counts, as a float."""
        return float(np.mean(self._units_processed))

    @property
    def mean_throughput(self):
        """Units processed per second over the current windows.

        Returns 0.0 when the windowed durations sum to zero (including when
        the window is empty).
        NOTE(review): the two windows are filled independently, so this
        assumes push_units_processed is called once per timed step.
        """
        # Inside a method, ``sum`` is the builtin, not the property above.
        time_total = sum(self._samples)
        if not time_total:
            return 0.0
        return sum(self._units_processed) / time_total

    def reset(self):
        """Discard all samples and counters, returning to the initial state."""
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0
@contextmanager
def _nullcontext(enter_result=None):
    """No-op context manager that hands back ``enter_result`` on entry.

    Used for mocking timer context: TimerCollection.record returns one of
    these when timing is disabled.
    """
    value = enter_result
    yield value
class TimerCollection:
    """A grouping of Timers.

    Timers are TimerStat instances created on first access by key.
    """

    def __init__(self):
        self._timers = collections.defaultdict(TimerStat)
        self._enabled = True

    def disable(self):
        """Make subsequent ``record`` calls return a no-op context."""
        self._enabled = False

    def enable(self):
        """Make subsequent ``record`` calls return real timers again."""
        self._enabled = True

    def reset(self):
        """Reset every timer created so far."""
        for stat in self._timers.values():
            stat.reset()

    def record(self, key):
        """Return a context manager that times a block under ``key``.

        While disabled, a no-op context is returned and no timer is created
        for ``key``.
        """
        if not self._enabled:
            return _nullcontext()
        return self._timers[key]

    def stats(self, mean=True, last=False):
        """Return a dict of per-timer aggregates for timers with samples.

        Keys are ``"mean_<key>_s"`` (if ``mean``) and ``"last_<key>_s"``
        (if ``last``); timers with a zero count are skipped.
        """
        aggregates = {}
        for name, stat in self._timers.items():
            if stat.count <= 0:
                continue
            if mean:
                aggregates["mean_%s_s" % name] = stat.mean
            if last:
                aggregates["last_%s_s" % name] = stat.last
        return aggregates
def find_free_port():
    """Return a TCP port number the OS reported as free.

    Binds an IPv4 TCP socket to port 0 on all interfaces so the OS picks a
    port, then closes the socket. Another process may claim the port before
    the caller binds it, so callers should tolerate that race.

    Returns:
        int: the OS-assigned port number.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock):
        sock.bind(("", 0))
        # NOTE(review): SO_REUSEADDR set after bind() cannot influence that
        # bind; it is kept to preserve existing behavior.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = sock.getsockname()
        return port
class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: the latest value; it is multiplied by ``n`` and added to
                the running sum.
            n (int): number of samples ``val`` stands for. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when no samples have been counted yet
        # (e.g. a first update with n=0); the previous average is kept.
        if self.count:
            self.avg = self.sum / self.count
class AverageMeterCollection:
    """A grouping of AverageMeters, one per metric name."""

    def __init__(self):
        self._batch_count = 0
        self.n = 0
        self._meters = collections.defaultdict(AverageMeter)

    def update(self, metrics, n=1):
        """Count one batch of ``n`` samples and feed each metric's value.

        Args:
            metrics (dict): metric name -> value for this batch.
            n (int): number of samples the batch represents. Defaults to 1.
        """
        self._batch_count += 1
        self.n += n
        for name, value in metrics.items():
            meter = self._meters[name]
            meter.update(value, n=n)

    def summary(self):
        """Returns a dict of average and most recent values for each metric."""
        stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
        for name, meter in self._meters.items():
            key = str(name)
            stats.update({key: meter.avg, "last_" + key: meter.val})
        return stats
def check_for_failure(remote_values):
    """Checks remote values for any that returned and failed.

    Waits on the given Ray object references until all of them have
    finished, fetching each ready batch so that actor failures surface.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            a SGD training loop in multiple parallel actor calls.

    Returns:
        Bool for success in executing given remote tasks: False if a
        RayActorError was raised (it is logged), True otherwise.

    Raises:
        Any exception from ``ray.get`` other than RayActorError propagates.
    """
    import ray
    from ray.exceptions import RayActorError

    pending = remote_values
    try:
        while len(pending) > 0:
            ready, pending = ray.wait(pending)
            # Results are discarded; fetching them is what raises on failure.
            ray.get(ready)
    except RayActorError as exc:
        logger.exception(str(exc))
        return False
    return True
def override(interface_class):
    """Decorator factory marking a method as overriding ``interface_class``.

    The returned decorator asserts that the method's name is present in
    ``dir(interface_class)`` and returns the method unchanged. Because the
    check is an ``assert``, it is skipped when Python runs with ``-O``.
    """
    def _check_and_return(method):
        name_is_known = method.__name__ in dir(interface_class)
        assert name_is_known
        return method

    return _check_and_return