zoo.orca.learn.pytorch package

Submodules

zoo.orca.learn.pytorch.constants module

zoo.orca.learn.pytorch.estimator module

class zoo.orca.learn.pytorch.estimator.Estimator[source]

Bases: object

evaluate(data, **kwargs)[source]
fit(data, epochs, **kwargs)[source]
static from_torch(*, model, optimizer, loss=None, scheduler_creator=None, training_operator_cls=<class 'zoo.orca.learn.pytorch.training_operator.TrainingOperator'>, initialization_hook=None, config=None, scheduler_step_freq='batch', use_tqdm=False, workers_per_node=1, model_dir=None, backend='horovod')[source]
get_model()[source]
load(checkpoint)[source]
predict(data, **kwargs)[source]
save(checkpoint)[source]
shutdown(force=False)[source]
class zoo.orca.learn.pytorch.estimator.PyTorchHorovodEstimatorWrapper(*, model_creator, optimizer_creator, loss_creator=None, scheduler_creator=None, training_operator_cls=<class 'zoo.orca.learn.pytorch.training_operator.TrainingOperator'>, initialization_hook=None, config=None, scheduler_step_freq='batch', use_tqdm=False, workers_per_node=1)[source]

Bases: zoo.orca.learn.pytorch.estimator.Estimator

evaluate(data, num_steps=None, profile=False, info=None)[source]
Parameters
  • data – (callable) A function that takes a config dict as input and returns a data loader containing the validation data.

  • num_steps – (int) Number of batches to evaluate on. This also corresponds to the number of times TrainingOperator.validate_batch is called.

  • profile – (bool) Whether to return time stats for the evaluation procedure.

  • info – (dict) Optional dictionary passed to the training operator for validate and validate_batch.

Returns

A dictionary of metrics for validation. You can provide custom metrics by passing in a custom training_operator_cls.

fit(data, epochs=1, profile=False, reduce_results=True, info=None)[source]
Parameters
  • data – (callable) A function that takes a config dict as input and returns a data loader containing the training data.

  • epochs – (int) Number of epochs to train the model.

  • profile – (bool) Whether to return time stats for the training procedure.

  • reduce_results – (bool) Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value (or nested dictionaries), one value will be randomly selected among the workers. If False, returns a list of dicts.

  • info – (dict) Optional dictionary passed to the training operator for train_epoch and train_batch.

Returns

(list) A list of stats dictionaries whose length equals epochs. Each stats dictionary contains the training metrics for one epoch.

You can provide custom metrics by passing in a custom training_operator_cls. If reduce_results=False, this will return a list of metric dictionaries whose length will be equal to num_workers.
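The reduction that reduce_results=True describes can be sketched in plain Python. This is a minimal illustration, not zoo.orca's code: reduce_worker_stats is a hypothetical helper, and it assumes every worker reports the same metric keys. Numeric values are averaged across workers, and any other value is taken from one randomly chosen worker.

```python
import random
from numbers import Number


def reduce_worker_stats(worker_stats, rng=random):
    """Hypothetical helper: merge a list of per-worker metric dicts into one dict.

    Numeric values are averaged across workers; non-numeric values
    (strings, nested dicts, ...) are taken from one randomly chosen worker.
    """
    reduced = {}
    for key in worker_stats[0]:
        values = [stats[key] for stats in worker_stats]
        numeric = all(isinstance(v, Number) and not isinstance(v, bool) for v in values)
        if numeric:
            reduced[key] = sum(values) / len(values)
        else:
            reduced[key] = rng.choice(values)
    return reduced


# Two workers reporting metrics for the same epoch.
per_worker = [
    {"loss": 0.5, "num_samples": 100, "tag": "worker-0"},
    {"loss": 0.3, "num_samples": 100, "tag": "worker-1"},
]
print(reduce_worker_stats(per_worker))
```

With reduce_results=False, the per-worker list (per_worker above) would be returned as-is instead of being merged.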

get_model()[source]

Returns the learned model(s).

load(checkpoint)[source]

Loads the Estimator and all workers from the provided checkpoint.

Parameters

checkpoint – (str) Path to target checkpoint file.

predict(data, **kwargs)[source]
save(checkpoint)[source]

Saves the Estimator state to the provided checkpoint path.

Parameters

checkpoint – (str) Path to target checkpoint file.

shutdown(force=False)[source]

Shuts down workers and releases resources.

class zoo.orca.learn.pytorch.estimator.PytorchSparkEstimatorWrapper(model, loss, optimizer, model_dir=None, bigdl_type='float')[source]

Bases: zoo.orca.learn.pytorch.estimator.Estimator

clear_gradient_clipping()[source]

Clears gradient clipping parameters, so that no gradient clipping is applied. To take effect, it must be called before fit.

evaluate(data, validation_methods=None, batch_size=32)[source]
fit(data, epochs=1, batch_size=32, validation_data=None, validation_methods=None, checkpoint_trigger=None)[source]
get_model()[source]
load(checkpoint)[source]
predict(data, **kwargs)[source]
save(checkpoint)[source]
set_constant_gradient_clipping(min, max)[source]

Sets constant gradient clipping during the training process. To take effect, it must be called before fit.

Parameters
  • min – The minimum value to clip by.

  • max – The maximum value to clip by.
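The effect of constant clipping can be sketched in plain Python, treating the gradients as a flat list of floats. This is an assumption for illustration only; the estimator operates on tensors and this is not its implementation. Each element is clamped into [min, max]:

```python
def clip_by_value(gradients, min_value, max_value):
    """Clamp every gradient element into [min_value, max_value]."""
    if min_value > max_value:
        raise ValueError("min_value must not exceed max_value")
    return [min(max(g, min_value), max_value) for g in gradients]


grads = [-3.0, -0.2, 0.0, 0.7, 5.0]
print(clip_by_value(grads, -1.0, 1.0))  # [-1.0, -0.2, 0.0, 0.7, 1.0]
```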

set_l2_norm_gradient_clipping(clip_norm)[source]

Clips gradients to a maximum L2 norm during the training process. To take effect, it must be called before fit.

Parameters

clip_norm – Gradient L2-norm threshold.
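L2-norm clipping is usually implemented by rescaling rather than clamping: if the joint L2 norm of the gradients exceeds the threshold, every element is multiplied by clip_norm / norm, which keeps the gradient direction. The sketch below shows that standard technique on a flat list of floats. It is an illustration under that assumption, not the library's internal code.

```python
import math


def clip_by_l2_norm(gradients, clip_norm):
    """Rescale gradients so that their joint L2 norm is at most clip_norm."""
    if clip_norm <= 0:
        raise ValueError("clip_norm must be positive")
    norm = math.sqrt(sum(g * g for g in gradients))
    if norm <= clip_norm:
        return list(gradients)  # already within the threshold: unchanged
    # Scale every element by clip_norm / norm; the direction is preserved.
    return [g * clip_norm / norm for g in gradients]


grads = [3.0, 4.0]  # L2 norm is 5.0
print(clip_by_l2_norm(grads, 1.0))  # [0.6, 0.8]
```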

shutdown(force=False)[source]

zoo.orca.learn.pytorch.pytorch_horovod_estimator module

zoo.orca.learn.pytorch.pytorch_trainer module

zoo.orca.learn.pytorch.torch_runner module

class zoo.orca.learn.pytorch.torch_runner.TorchRunner(model_creator, optimizer_creator, loss_creator=None, scheduler_creator=None, training_operator_cls=None, config=None, use_tqdm=False, scheduler_step_freq=None)[source]

Bases: object

Manages a PyTorch model for training.

apply(fn)[source]
apply_operator(fn)[source]
find_free_port()[source]

Finds a free port on the current node.
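A common way to find a free port, shown as a standalone sketch (assumed, not taken from zoo.orca): bind a TCP socket to port 0 so the operating system assigns an unused port, then read the port number back. The port is released when the socket closes, so another process could claim it before it is used.

```python
import socket
from contextlib import closing


def find_free_port():
    """Ask the OS for an unused TCP port by binding to port 0."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(("", 0))  # port 0 lets the kernel pick any free port
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.getsockname()[1]


port = find_free_port()
print(port)
```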

get_node_ip()[source]

Returns the IP address of the current node.

given_models
given_optimizers
given_schedulers
load_state_dict(state)[source]

Sets the state of the model.

load_state_stream(byte_obj)[source]

Loads the training state dict from a bytes object.

setup_components()[source]

Runs the creator functions without any distributed coordination.

setup_operator()[source]

Creates the training operator.

static should_wrap_dataloader(loader)[source]
shutdown()[source]

Attempts to shut down the worker.

state_dict()[source]

Returns the state of the runner.

state_stream()[source]

Returns a bytes object for the state dict.
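A minimal round-trip sketch of what state_stream and load_state_stream describe: the state dict is written into an in-memory buffer and returned as bytes, then rebuilt from those bytes. pickle is swapped in here so the example runs without PyTorch; for tensors, torch.save and torch.load on an io.BytesIO buffer would be the usual equivalents. state_to_bytes, state_from_bytes and the dict keys are hypothetical, and bytes from untrusted sources should never be unpickled.

```python
import io
import pickle


def state_to_bytes(state_dict):
    """Serialize a state dict into a bytes object (pickle stands in for torch.save)."""
    buffer = io.BytesIO()
    pickle.dump(state_dict, buffer)
    return buffer.getvalue()


def state_from_bytes(byte_obj):
    """Rebuild the state dict from the bytes produced by state_to_bytes."""
    return pickle.load(io.BytesIO(byte_obj))


state = {"epoch": 3, "models": [{"weight": [0.1, 0.2]}], "optimizers": [{"lr": 0.01}]}
blob = state_to_bytes(state)
restored = state_from_bytes(blob)
print(type(blob).__name__, restored == state)  # bytes True
```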

train_epoch(data_loader, profile=False, info=None)[source]

Runs a training epoch and updates the model parameters.

train_epochs(data_creator, epochs=1, profile=False, info=None)[source]
validate(data_creator, num_steps=None, profile=False, info=None)[source]

Evaluates the model on the validation data set.

with_sampler(loader)[source]

zoo.orca.learn.pytorch.training_operator module

class zoo.orca.learn.pytorch.training_operator.TrainingOperator(config, models, optimizers, world_rank, criterion=None, schedulers=None, device_ids=None, use_gpu=False, use_fp16=False, use_tqdm=False)[source]

Bases: object

Abstract class for custom training or validation loops.

The scheduler is stepped only at batch or epoch frequency, depending on the scheduler_step_freq parameter. Be sure to set scheduler_step_freq in TorchTrainer to either “batch” or “epoch” to increment the scheduler correctly during training. If using a learning rate scheduler that depends on validation loss, you can use trainer.update_scheduler.
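The difference between the two settings can be sketched without PyTorch. CountingScheduler is a hypothetical stand-in that only counts step() calls, and run_training is a hypothetical loop, not TorchTrainer's code. With "batch" the scheduler steps after every batch; with "epoch" it steps once at the end of each epoch.

```python
class CountingScheduler:
    """Hypothetical stand-in for a torch LR scheduler that only counts step() calls."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


def run_training(num_epochs, batches_per_epoch, scheduler, scheduler_step_freq):
    """Step the scheduler once per batch or once per epoch, as the flag requests."""
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            # forward pass, backward pass and optimizer.step() would go here
            if scheduler_step_freq == "batch":
                scheduler.step()
        if scheduler_step_freq == "epoch":
            scheduler.step()


per_batch, per_epoch = CountingScheduler(), CountingScheduler()
run_training(2, 5, per_batch, "batch")
run_training(2, 5, per_epoch, "epoch")
print(per_batch.steps, per_epoch.steps)  # 10 2
```

A schedule tuned in epochs (for example, "decay every 30 epochs") would therefore decay far too early if the scheduler were stepped per batch.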

For both training and validation, you can customize the loop at two granularities: per epoch or per batch. You do not need to override both.

Raises

ValueError – If multiple models/optimizers/schedulers are provided. You are expected to subclass this class if you wish to train over multiple models/optimizers/schedulers.

config

Provided to TorchTrainer.

Type

dict

criterion

Criterion created by the provided loss_creator.

device

The appropriate torch device, at your convenience.

Type

torch.device

device_ids

Device IDs for the model.

This is useful for using batch norm with DistributedDataParallel.

Type

List[int]

load_state_dict(state_dict)[source]

Override this to load the representation of the operator state.

Parameters

state_dict (dict) – State dict as returned by the operator.

model

First or only model created by the provided model_creator.

models

List of models created by the provided model_creator.

optimizer

First or only optimizer created by the optimizer_creator.

optimizers

List of optimizers created by the optimizer_creator.

scheduler

First or only scheduler created by the scheduler_creator.

schedulers

List of schedulers created by the scheduler_creator.

setup(config)[source]

Override this method to implement custom operator setup.

Parameters

config (dict) – Custom configuration value to be passed to all creator and operator constructors. Same as self.config.

state_dict()[source]

Override this to return a representation of the operator state.

Returns

The state dict of the operator.

Return type

dict

train_batch(batch, batch_info)[source]

Computes loss and updates the model over one batch.

This method is responsible for computing the loss and gradient and updating the model.

By default, this method assumes that batches are in (*features, labels) format, so models with multiple inputs are also supported. If using amp/fp16 training, it also scales the loss automatically.

You can provide custom loss metrics and training operations if you override this method. If overriding this method, you can access model, optimizer, criterion via self.model, self.optimizer, and self.criterion.

You do not need to override this method if you plan to override train_epoch.

Parameters
  • batch – One item of the training iterator.

  • batch_info (dict) – Information dict passed in from train_epoch.

Returns

A dictionary of metrics.

By default, this dictionary contains “loss” and “num_samples”. “num_samples” corresponds to the number of data points in the batch. However, you can provide any number of other values. Consider returning “num_samples” in the metrics because, by default, train_epoch uses “num_samples” to calculate averages.
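To make the (*features, labels) convention concrete, here is a framework-free sketch. train_batch_sketch, two_input_model and mse are hypothetical names, plain lists stand in for tensors, and the backward pass and optimizer step that the real operator performs are omitted. The last element of the batch is treated as the labels and every other element is passed to the model as a separate input.

```python
def train_batch_sketch(model, criterion, batch):
    """Hypothetical illustration of the (*features, labels) batch convention."""
    *features, target = batch      # the last element is always the labels
    output = model(*features)      # every other element is a model input
    loss = criterion(output, target)
    # The real operator would run loss.backward() and optimizer.step() here.
    return {"loss": loss, "num_samples": len(target)}


def two_input_model(x1, x2):
    """Toy model with two inputs: element-wise sum."""
    return [a + b for a, b in zip(x1, x2)]


def mse(output, target):
    """Mean squared error over paired elements."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)


batch = ([1.0, 2.0], [0.5, 0.5], [1.5, 3.0])  # (x1, x2, labels)
print(train_batch_sketch(two_input_model, mse, batch))  # {'loss': 0.125, 'num_samples': 2}
```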

train_epoch(iterator, info)[source]

Runs one standard training pass over the training dataloader.

By default, this method will iterate over the given iterator and call self.train_batch over each batch. If scheduler_step_freq is set, this default method will also step the scheduler accordingly.

You do not need to call train_batch in this method if you plan to implement a custom optimization/training routine here.

You may find zoo.orca.learn.pytorch.utils.AverageMeterCollection useful when overriding this method. See the example below:

from zoo.orca.learn.pytorch.utils import AverageMeterCollection

def train_epoch(self, iterator, info):
    meter_collection = AverageMeterCollection()
    self.model.train()
    for batch in iterator:
        # Batches are (*features, labels) tuples by default.
        *features, target = batch
        # do some processing
        metrics = {"metric_1": 1, "metric_2": 3}  # dict of metrics

        # This keeps track of all metrics across multiple batches,
        # weighting each batch by its number of samples.
        meter_collection.update(metrics, n=len(target))

    # Returns stats of the meters.
    stats = meter_collection.summary()
    return stats
Parameters
  • iterator (iter) – Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed.

  • info (dict) – Dictionary for information to be used for custom training operations.

Returns

A dict of metrics from training.

use_fp16

Whether the model and optimizer have been FP16 enabled.

Type

bool

use_gpu

Returns True if cuda is available and use_gpu is True.

use_tqdm

Whether tqdm progress bars are enabled.

Type

bool

validate(val_iterator, info)[source]

Runs one standard validation pass over the val_iterator.

This calls model.eval() and runs under torch.no_grad() while iterating over the validation dataloader.

If overriding this method, you can access model, criterion via self.model and self.criterion. You also do not need to call validate_batch if overriding this method.

Parameters
  • val_iterator (iter) – Iterable constructed from the validation dataloader.

  • info (dict) – Dictionary for information to be used for custom validation operations.

Returns

A dict of metrics from the evaluation.

By default, returns “val_accuracy” and “val_loss”, which are computed by aggregating the “loss” and “correct” values from validate_batch and dividing them by the sum of num_samples from all calls to self.validate_batch.
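One reading of that aggregation, sketched in plain Python. It assumes each per-batch result carries its mean “loss”, a count of “correct” predictions and “num_samples”; aggregate_validation is a hypothetical helper, not the operator's code. Weighting by num_samples keeps a small final batch from counting as much as a full one.

```python
def aggregate_validation(batch_results):
    """Combine per-batch results into sample-weighted validation metrics."""
    total_samples = sum(r["num_samples"] for r in batch_results)
    # Each batch reports its mean loss, so weight it by the batch size.
    total_loss = sum(r["loss"] * r["num_samples"] for r in batch_results)
    total_correct = sum(r["correct"] for r in batch_results)
    return {
        "val_loss": total_loss / total_samples,
        "val_accuracy": total_correct / total_samples,
    }


results = [
    {"loss": 0.5, "correct": 8, "num_samples": 10},
    {"loss": 1.0, "correct": 1, "num_samples": 2},  # smaller final batch
]
print(aggregate_validation(results))
```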

validate_batch(batch, batch_info)[source]

Calculates the loss and accuracy over a given batch.

You can override this method to provide arbitrary metrics.

As with train_batch, this method assumes by default that batches are in (*features, labels) format, so models with multiple inputs are also supported.

Parameters
  • batch – One item of the validation iterator.

  • batch_info (dict) – Contains information per batch from validate().

Returns

A dict of metrics.

By default, returns “val_loss”, “val_accuracy”, and “num_samples”. When overriding, consider returning “num_samples” in the metrics because by default, validate uses “num_samples” to calculate averages.

world_rank

The rank of the parent runner. Always 0 if not distributed.

Type

int

zoo.orca.learn.pytorch.utils module

class zoo.orca.learn.pytorch.utils.AverageMeter[source]

Bases: object

Computes and stores the average and current value.

reset()[source]
update(val, n=1)[source]
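A minimal sketch of the behavior described above (“computes and stores the average and current value”). AverageMeterSketch and its attributes val, sum, count and avg are assumed names for illustration; the real class may differ in detail. The average is weighted by n, so a batch mean passed with n equal to the batch size contributes in proportion to its samples.

```python
class AverageMeterSketch:
    """Tracks the latest value and the running, n-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value passed to update()
        self.sum = 0.0   # sum of val * n over all updates
        self.count = 0   # total weight n seen so far
        self.avg = 0.0   # sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0, n=3)  # e.g. mean loss 2.0 over a batch of 3 samples
meter.update(4.0, n=1)
print(meter.val, meter.avg)  # 4.0 2.5
```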
class zoo.orca.learn.pytorch.utils.AverageMeterCollection[source]

Bases: object

A grouping of AverageMeters.

summary()[source]

Returns a dict of average and most recent values for each metric.

update(metrics, n=1)[source]
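A standalone sketch of how such a collection could produce a summary of average and most recent values per metric. The class name and the “last_” key prefix are assumptions for illustration, and the real summary() may report additional keys.

```python
from collections import defaultdict


class AverageMeterCollectionSketch:
    """Keeps one n-weighted running average and one latest value per metric name."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)
        self._last = {}

    def update(self, metrics, n=1):
        for name, value in metrics.items():
            self._sums[name] += value * n
            self._counts[name] += n
            self._last[name] = value

    def summary(self):
        stats = {}
        for name in self._last:
            stats[name] = self._sums[name] / self._counts[name]
            stats["last_" + name] = self._last[name]
        return stats


meters = AverageMeterCollectionSketch()
meters.update({"loss": 1.0, "acc": 0.5}, n=2)
meters.update({"loss": 0.25, "acc": 1.0}, n=2)
print(meters.summary())  # {'loss': 0.625, 'last_loss': 0.25, 'acc': 0.75, 'last_acc': 1.0}
```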
class zoo.orca.learn.pytorch.utils.TimerCollection[source]

Bases: object

A grouping of Timers.

disable()[source]
enable()[source]
record(key)[source]
reset()[source]
stats(mean=True, last=False)[source]
class zoo.orca.learn.pytorch.utils.TimerStat(window_size=10)[source]

Bases: object

A running stat for conveniently logging the duration of a code block.

Note that this class is not thread-safe.

Examples

Time a call to ‘time.sleep’.

>>> import time
>>> from zoo.orca.learn.pytorch.utils import TimerStat
>>> sleep_timer = TimerStat()
>>> with sleep_timer:
...     time.sleep(1)
>>> round(sleep_timer.mean)
1
first
last
max
mean
mean_throughput
mean_units_processed
median
push(time_delta)[source]
push_units_processed(n)[source]
reset()[source]
size
sum
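Because TimerStat takes a window_size, a plausible reading is that its statistics cover only the most recent window_size durations. The sketch below assumes exactly that. WindowedTimer is a hypothetical class, not the library's implementation, and it covers only push, size, mean, median and max. Like TimerStat, it is not thread-safe.

```python
import statistics
import time
from collections import deque


class WindowedTimer:
    """Keeps the durations of the last `window_size` timed blocks (not thread-safe)."""

    def __init__(self, window_size=10):
        self._samples = deque(maxlen=window_size)  # oldest entries drop off automatically
        self._start = None

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.push(time.perf_counter() - self._start)
        self._start = None
        return False  # never swallow exceptions raised inside the timed block

    def push(self, time_delta):
        self._samples.append(time_delta)

    @property
    def size(self):
        return len(self._samples)

    @property
    def mean(self):
        return statistics.mean(self._samples)

    @property
    def median(self):
        return statistics.median(self._samples)

    @property
    def max(self):
        return max(self._samples)  # builtin max: class attributes are not in scope here


timer = WindowedTimer(window_size=3)
for delta in [1.0, 2.0, 3.0, 4.0]:
    timer.push(delta)  # only the last three pushes are kept
size_after_pushes, mean_after_pushes = timer.size, timer.mean
print(size_after_pushes, mean_after_pushes)  # 3 3.0

with timer:
    time.sleep(0.01)  # the measured duration replaces the oldest sample
print(timer.size, timer.max)
```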
zoo.orca.learn.pytorch.utils.check_for_failure(remote_values)[source]

Checks remote values for any that returned and failed.

Parameters

remote_values (list) – List of object IDs representing functions that may fail in the middle of execution. For example, running an SGD training loop in multiple parallel actor calls.

Returns

Bool for success in executing given remote tasks.
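The same kind of check can be sketched with the standard library, with concurrent.futures futures standing in for the remote object IDs (an assumption made so the example runs without a cluster). check_for_failure_sketch, ok and boom are hypothetical names. The function waits for tasks as they finish and returns False as soon as one of them raised.

```python
import concurrent.futures as cf


def check_for_failure_sketch(futures):
    """Return True if every future finished without raising, else False."""
    try:
        for future in cf.as_completed(futures):
            future.result()  # re-raises any exception from the task
        return True
    except Exception as exc:  # a task failed in the middle of execution
        print(f"task failed: {exc!r}")
        return False


def ok(x):
    return x * 2


def boom(x):
    raise RuntimeError(f"worker {x} crashed")


with cf.ThreadPoolExecutor(max_workers=2) as pool:
    good = [pool.submit(ok, i) for i in range(3)]
    bad = good + [pool.submit(boom, 3)]
    all_ok = check_for_failure_sketch(good)
    any_failed = check_for_failure_sketch(bad)
print(all_ok, any_failed)
```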

zoo.orca.learn.pytorch.utils.find_free_port()[source]
zoo.orca.learn.pytorch.utils.override(interface_class)[source]

Module contents