PyTorch Estimator


Introduction

Analytics Zoo Orca PyTorch Estimator provides a set of APIs for running PyTorch models on Spark in a distributed fashion.


Orca PyTorch Estimator

Orca PyTorch Estimator is an estimator to do PyTorch training/evaluation/prediction on Spark in a distributed fashion.

It supports various data types, such as XShards, PyTorch DataLoader, and PyTorch DataLoader creator functions.

It supports both the Horovod backend and the BigDL backend through unified APIs.

Create Estimator from PyTorch Model

You can create an Orca PyTorch Estimator from a native PyTorch model.

from zoo.orca.learn.pytorch import Estimator

Estimator.from_torch(*,
                     model,
                     optimizer,
                     loss=None,
                     scheduler_creator=None,
                     training_operator_cls=TrainingOperator,
                     initialization_hook=None,
                     config=None,
                     scheduler_step_freq="batch",
                     use_tqdm=False,
                     workers_per_node=1,
                     model_dir=None,
                     backend="bigdl")

Use Horovod Estimator

Train model

After an Estimator is created, you can call its fit method to train the PyTorch model:

fit(self, data, epochs=1, profile=False, reduce_results=True, info=None)
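
For example, a hedged sketch with the Horovod backend, assuming est was created with backend="horovod" and that data is a PyTorch DataLoader creator function (one of the supported data types listed above):

# a DataLoader creator: takes the config dict and returns a DataLoader
def train_loader_creator(config):
    import torch
    from torch.utils.data import TensorDataset, DataLoader
    x = torch.randn(1000, 10)          # illustrative random features
    y = torch.randint(0, 2, (1000,))   # illustrative random labels
    return DataLoader(TensorDataset(x, y),
                      batch_size=config.get("batch_size", 32))

stats = est.fit(data=train_loader_creator, epochs=2)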

Evaluate model

After training, you can call its evaluate method to evaluate the PyTorch model:

evaluate(self, data, num_steps=None, profile=False, info=None)
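
For example, assuming val_loader_creator is a DataLoader creator function like the one sketched above:

val_stats = est.evaluate(data=val_loader_creator)
print(val_stats)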

Get model

You can get the trained model using get_model(self).

Save model

You can save the model using save(self, checkpoint).
* checkpoint: (str) Path to the target checkpoint file.

Load model

You can load a saved model using load(self, checkpoint).
* checkpoint: (str) Path to the target checkpoint file.
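
A small sketch of saving and later restoring a checkpoint (the path below is illustrative):

ckpt_path = "/tmp/orca_pytorch_checkpoint"   # illustrative path
est.save(ckpt_path)
# ... later, e.g. on a freshly created Estimator
est.load(ckpt_path)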

Shutdown workers

You can shut down workers and release resources using shutdown(self, force=False).

Use BigDL Estimator

Train model

After an Estimator is created, you can call its fit method to train the PyTorch model:

fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_methods=None, checkpoint_trigger=None)
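
For example, a minimal sketch assuming est was created with backend="bigdl" and that the data is a plain PyTorch DataLoader (one of the supported data types; XShards works similarly):

import torch
from torch.utils.data import TensorDataset, DataLoader

x = torch.randn(1000, 10)          # illustrative random features
y = torch.randint(0, 2, (1000,))   # illustrative random labels
train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

est.fit(data=train_loader, epochs=2)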

Evaluate model

After training, you can call its evaluate method to evaluate the PyTorch model:

evaluate(self, data, validation_methods=None, batch_size=32)
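
For example, assuming val_loader is a DataLoader like the one above (depending on your version, you may also need to pass metric objects via validation_methods):

result = est.evaluate(data=val_loader, batch_size=32)
print(result)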

Get model

You can get the trained model using get_model(self).

Load model

You can load a saved model using load(self, checkpoint, loss=None).
* checkpoint: (str) Path to the target checkpoint file.
* loss: PyTorch loss function.
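
A small sketch, assuming the checkpoint was produced by a previous BigDL-backend run (the path is illustrative):

import torch.nn as nn

est.load("/tmp/orca_pytorch_checkpoint", loss=nn.CrossEntropyLoss())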

Clear gradient clipping

You can clear the gradient clipping parameters using clear_gradient_clipping(self); after that, gradient clipping will not be applied. Note: to take effect, it must be called before fit.

Set constant gradient clipping

You can set constant gradient clipping during training using set_constant_gradient_clipping(self, min, max).
* min: The minimum value to clip by.
* max: The maximum value to clip by.
Note: to take effect, it must be called before fit.
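
For example, to clip every gradient element into [-5.0, 5.0] (values are illustrative), call it before fit:

est.set_constant_gradient_clipping(min=-5.0, max=5.0)
est.fit(data=train_loader, epochs=2)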

Clip gradient to a maximum L2-Norm

You can clip gradients to a maximum L2-norm during training using set_l2_norm_gradient_clipping(self, clip_norm).
* clip_norm: Gradient L2-norm threshold.
Note: to take effect, it must be called before fit.
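
For example, to clip gradients to an L2-norm of 1.0 (an illustrative value), again calling it before fit; clear_gradient_clipping can be used first to remove any earlier setting:

est.clear_gradient_clipping()                     # remove any previous clipping setting
est.set_l2_norm_gradient_clipping(clip_norm=1.0)
est.fit(data=train_loader, epochs=2)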