Module stoke.stoke
API interface to Stoke that handles any necessary config, context, setup etc.
??? example "View Source" # -- coding: utf-8 --
# Copyright FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0
"""API interface to Stoke that handles any necessary config, context, setup etc."""
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from uuid import uuid4
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import ShardedDataParallel as SDDP
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset
from torch.utils.data.distributed import Sampler
from stoke.configs import (
AMPConfig,
ApexConfig,
ClipGradConfig,
ClipGradNormConfig,
DDPConfig,
DeepspeedConfig,
FairscaleFSDPConfig,
FairscaleOSSConfig,
FairscaleSDDPConfig,
HorovodConfig,
StokeOptimizer,
)
from stoke.data import StokeDataLoader
from stoke.distributed import RunnerDistEnum
from stoke.extensions import RunnerOptimizerEnum
from stoke.fp16 import RunnerFP16Enum
from stoke.io import RunnerIOEnum
from stoke.status import DistributedOptions, FP16Options, StokeStatus
from stoke.utils import (
ParamNormalize,
T_co,
_collate_fn_t,
_worker_init_fn_t,
zero_optimizer_grads,
)
class Stoke:
"""High level stoke object that manages all necessary configs and provides a unified interface to ops
This is the main class within Stoke. Functionally it manages all interfaces to the necessary wrapped ops (model,
loss, backward, step), provides helper functions, and dynamically constructs the runtime that handles the
combinatorics problem of underlying frameworks (DDP, Horovod, Deepspeed, Fairscale),
mixed-precision (AMP or APEX) and devices (CPU or GPU)
Attributes
----------
amp_config
apex_config
batch_size
cuda
ddp_config
deepspeed_config
distributed
effective_batch_size
ema_loss
fp16
fsdp_config
fully_sharded
gpu
grad_accum
grad_clip
horovod_config
is_amp
is_apex
is_ddp
is_deepspeed
is_horovod
loss_access
model_access
nccl
num_model_parameters
optimizer
oss
oss_config
rank
scaler
sddp_config
sharded
status
world_size
_agg_loss: Union[float, List[float], Tuple[float]]
aggregated loss for grad accumulation (single or multiple losses)
_backward_steps: int
Number of times gradients have been calculated on a batch of samples (calls to backward)
_grad_accum_counter: int
counter for grad accumulation steps
_loss: Union[Callable, List[Callable], Tuple[Callable]]
callable function that calculates a loss from the model outputs
_last_step_loss: list, tuple, or float
last loss step calculation aggregated over device(s)
_model: torch.nn.Module
instance of torch.nn.Module for Stoke to handle
_optimizer: StokeOptimizer
StokeOptimizer config object that describes the torch.optim.Optimizer and its kwargs
_optimizer_steps: int
Number of times step has been called on the optimizer
_runner: StokeRunner
the dynamically created runtime object that handles all ops
_status: StokeStatus
StokeStatus object that sets and maintains the current configuration
_verbose: bool
print verbosity
_rolling_loss_steps: int
number of steps that have been called for the rolling loss
_rolling_mean_loss: list, tuple, or float
current ema loss
_ema_weight: float
weight used for any ema calculation on metrics
"""
def __init__(
self,
model: torch.nn.Module,
optimizer: StokeOptimizer,
loss: Union[Callable, List[Callable], Tuple[Callable]],
batch_size_per_device: int,
grad_accum_steps: Optional[int] = 1,
grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]] = None,
gpu: bool = False,
fp16: Optional[FP16Options] = None,
distributed: Optional[DistributedOptions] = None,
fairscale_oss: bool = False,
fairscale_sddp: bool = False,
fairscale_fsdp: bool = False,
configs: Optional[
List[
Union[
AMPConfig,
ApexConfig,
DDPConfig,
DeepspeedConfig,
FairscaleOSSConfig,
FairscaleSDDPConfig,
FairscaleFSDPConfig,
HorovodConfig,
]
]
] = None,
info_rank: Optional[Union[int, List[int]]] = 0,
verbose: bool = True,
ema_weight: float = 0.1,
):
"""Init for Stoke class object
Parameters
----------
model: torch.nn.Module
PyTorch model
optimizer: StokeOptimizer
Optimizer configuration
loss: Union[Callable, List[Callable], Tuple[Callable]]
Callable loss function or functions
batch_size_per_device: int
Batch size at the single device level
grad_accum_steps: Optional[int], default: 1
Number of gradient accumulation steps
grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]], default: None
Gradient clipping configuration
gpu: bool, default: False
flag to use GPU device(s)
fp16: Optional[FP16Options], default: None
Choice of mixed-precision backend
distributed: Optional[DistributedOptions], default: None
Choice of distributed backend
fairscale_oss: bool, default: False
Flag to activate optimizer state sharding using Fairscale
fairscale_sddp: bool, default: False
Flag to activate sharded DDP using Fairscale
fairscale_fsdp: bool, default: False
Flag to activate fully sharded DDP using Fairscale
configs: Optional[List[Union[AMPConfig, ApexConfig, DDPConfig, DeepspeedConfig, FairscaleOSSConfig, FairscaleSDDPConfig, FairscaleFSDPConfig, HorovodConfig]], default: None
Configuration objects for runtimes
info_rank: Optional[Union[int, List[int]]], default = 0
Constrain prints to specific devices
verbose: bool, default: True
Flag for verbosity
ema_weight: float, default: 0.1
weight used for any ema calculation on metrics
"""
# Verbosity
self._verbose = verbose
# Info rank
self._info_rank = info_rank
# EMA
self._ema_weight = ema_weight
# Setup the StokeState
self._status = StokeStatus(
batch_size_per_device=batch_size_per_device,
grad_accum=grad_accum_steps,
grad_clip=grad_clip,
gpu=gpu,
fp16=fp16,
distributed=distributed,
fairscale_oss=fairscale_oss,
fairscale_sddp=fairscale_sddp,
fairscale_fsdp=fairscale_fsdp,
configs=configs,
)
# Run some checks
self._model = self._check_model(model)
self._optimizer = self._check_optimizer(optimizer)
self._loss = self._check_loss(loss)
# Dynamically construct the StokeRunner from the StokeStatus
self._runner, class_info = self._build_runner()
# Setup distributed backend
self._runner.setup_distributed()
# Past this point the runner exposes the print_device function that self.print maps to,
# as it needs rank to be accessible before it can work
if self._verbose:
dev_id = (
self.rank
if (self.rank == "cpu" or self.rank == "gpu")
else self._info_rank
)
self.print(f"Printing verbose information on rank(s): {dev_id}")
# Print the runner class info from the mixins
self.print(class_info)
# Possibly place model on GPU depending on StokeStatus -- before wrap calls
self._place_model_on_gpu()
# Handle the wrap ops in the correct order
self._handle_ordered_wrap_ops(optimizer=optimizer)
# Create some tracking vars
self._grad_accum_counter = 0
self._optimizer_steps = 0
self._backward_steps = 0
self._last_step_loss = self._set_loss_to_zero()
self._agg_loss = self._set_loss_to_zero()
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
# Set post-init status variables
self._status.set_post_init_values(world_size=self.world_size)
# Print the final configuration
if self._verbose:
self.print(msg=self._status)
def _wrap_optimizer_then_model(self, optimizer: StokeOptimizer):
"""Handles wrapping of optimizer then the model
This holds only for SDDP + OSS, Horovod, and APEX, as these need an instantiated optimizer before the wrapped
methods are called
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# Build the optimizer
self._optimizer = self._runner.build_optimizer(
optimizer=optimizer["optimizer"],
optimizer_kwargs=optimizer["optimizer_kwargs"],
model=self._model,
)
# Setup/Initialize FP16 backend -- in this case the optimizer is passed through
self._runner.wrap_fp16(model=self._model, optimizer=self._optimizer)
# Wrap with distributed backend -- in this case the optimizer is passed through
self._model, self._optimizer = self._runner.wrap_distributed(
model=self._model, grad_accum=self.grad_accum, optimizer=self._optimizer
)
def _wrap_model_then_optimizer(self, optimizer: StokeOptimizer):
"""Handles wrapping of model then optimizer
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# Wrap with distributed backend -- in this case the optimizer is passed as None since it doesn't exist yet
# don't use the return for the optimizer in this case
self._model, _ = self._runner.wrap_distributed(
model=self._model, grad_accum=self.grad_accum, optimizer=None
)
# Setup/Initialize FP16 backend -- in this case the optimizer is passed as None since it doesn't exist yet
self._runner.wrap_fp16(model=self._model, optimizer=None)
# Build the optimizer
self._optimizer = self._runner.build_optimizer(
optimizer=optimizer["optimizer"],
optimizer_kwargs=optimizer["optimizer_kwargs"],
model=self._model,
)
def _handle_ordered_wrap_ops(self, optimizer: StokeOptimizer):
"""Handles wrapping model, using FP16, and wrapping optimizer in the correct order depending on Stoke Status
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# if SDDP + OSS, Horovod, or APEX then we need to make sure that the optimizer gets wrapped before the model
# gets wrapped; all other setups follow the standard DDP paradigm (or DeepSpeed's own handling)
if (self.sharded and self.oss) or self.is_apex or self.is_horovod:
self._wrap_optimizer_then_model(optimizer=optimizer)
else:
self._wrap_model_then_optimizer(optimizer=optimizer)
def _check_accum(self):
"""Checks if the current step is the last accumulation step
Returns
-------
bool
"""
return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == 0
def _check_pre_accum(self):
"""Checks if we are at the pre-accumulate step
Returns
-------
bool
"""
return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == self.grad_accum
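# A worked trace of the two checks above (an illustrative example): with grad_accum=3
# the counter is incremented once per backward() call, so
#   counter=1 -> (1+1) % 4 == 2 -> neither check fires
#   counter=2 -> (2+1) % 4 == 3 -> _check_pre_accum() is True (equals grad_accum)
#   counter=3 -> (3+1) % 4 == 0 -> _check_accum() is True and step() will run the optimizer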
def _set_loss_to_zero(self):
"""Used to set a loss tracker to zero depending on the type
Returns
-------
float or list or tuple of reset loss
"""
return (
type(self._loss)([0.0] * len(self._loss))
if isinstance(self._loss, (list, tuple))
else 0.0
)
def reset_ema(self):
"""Used to reset the current state of the rolling mean loss
Returns
-------
None
"""
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
def print_ema_loss(
self, prepend_msg: str = "Current EMA Loss", single_line: bool = False
):
"""Prints the current ema loss synced across all devices
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Current EMA Loss"
message prepend to print
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
if isinstance(self._rolling_mean_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val:.3f}"
for idx, val in enumerate(self._rolling_mean_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._rolling_mean_loss:.3f}")
def print_mean_accumulated_synced_loss(
self,
prepend_msg: str = "Mean Accumulated & Synced Loss",
pre_backwards: bool = True,
single_line: bool = False,
):
"""Prints the mean accumulated and device synced loss only after the grad accumulation step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Mean Accumulated & Synced Loss"
message prepend to print
pre_backwards: bool, default: True
if being called pre backward step
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
check_fn = self._check_pre_accum if pre_backwards else self._check_accum
if check_fn():
if isinstance(self._agg_loss, (list, tuple)):
print_vals = self._scale_agg_loss()
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._scale_agg_loss():.3f}")
def _scale_agg_loss(self):
"""Scales the mean aggregated loss by grad accum
Returns
-------
scale_vals: list or float of mean aggregated loss
"""
if isinstance(self._agg_loss, (list, tuple)):
scale_vals = [
val / self.grad_accum for idx, val in enumerate(self._agg_loss)
]
else:
scale_vals = self._agg_loss / self.grad_accum
return scale_vals
def print_synced_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
prepend_msg: str = "Step Synced Loss",
device=None,
single_line: bool = False,
):
"""Prints a device synced loss at a single step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es) on the device
prepend_msg: str, default: "Step Synced Loss"
message prepend to print
device: default: None
specify the device to place the synced loss on (defaults to same device)
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
printable_loss = self.detach_and_sync_loss(loss, device)
if isinstance(printable_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val * self.grad_accum:.3f}"
for idx, val in enumerate(printable_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(msg=f"{prepend_msg}: {printable_loss * self.grad_accum:.3f}")
def print_on_devices(
self, msg: Union[str, List[str]], rank: Optional[Union[int, List[int]]] = 0
):
"""Wraps runner print interface for shorter semantics
Parameters
----------
msg: str
message to print
rank: Union[int, List[int]], default: 0
which ranks to print on
Returns
-------
None
"""
self._runner.print_device(msg=msg, rank=rank)
def print(self, msg: Union[str, List[str]], single_line: bool = False):
"""Wraps the runners print device and forces print on the _info_rank attribute(s)
Parameters
----------
msg: str
message to print
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
self._runner.print_device(
msg=msg, rank=self._info_rank, single_line=single_line
)
@staticmethod
def _check_model(model: torch.nn.Module):
"""Verifies the type of the model
Parameters
----------
model: torch.nn.Module
current torch model
Returns
-------
torch.nn.Module
the validated model
"""
# Check if the model is an nn.Module such that it has a forward method
if not isinstance(model, torch.nn.Module):
raise TypeError(
f"Stoke -- Model is not of type torch.nn.Module, currently {type(model)}"
)
return model
@staticmethod
def _check_optimizer(optimizer: StokeOptimizer):
"""Verifies the type of the optimizer
Parameters
----------
optimizer: StokeOptimizer
Current optimizer configuration TypedDict (aka dict)
Returns
-------
StokeOptimizer
the validated optimizer configuration
"""
if not isinstance(optimizer, dict):
raise TypeError(
f"Stoke -- Optimizer is not of type torch.optim.Optimizer, currently {type(optimizer)}"
)
return optimizer
def _check_loss(self, loss: Union[Callable, List[Callable], Tuple[Callable]]):
"""Checks to make sure the loss function(s) is/are callable
Parameters
----------
loss: Union[Callable, List[Callable], Tuple[Callable]]
Current callable loss(es)
Returns
-------
Union[Callable, List[Callable], Tuple[Callable]]
the validated loss callable(s)
"""
if isinstance(loss, (list, tuple)):
loss = [self._check_loss(val) for val in loss]
return loss
elif isinstance(loss, Callable):
return loss
else:
raise TypeError(
f"Stoke -- Loss is not of type Callable, currently {type(loss)}"
)
def _place_model_on_gpu(self):
"""Automatically moves the model to GPU device(s)
Returns
-------
None
"""
if self.gpu and not self.is_deepspeed:
if self._verbose:
self.print(f"Automatically handling moving model to GPU(s)...")
self._model.cuda()
def _build_runner(self):
"""Builds the runtime object from the mixin style classes
Mixes the distributed class, fp16 class, and optimizer class into a single object such that all can be called
from the same interface. Prevents verbose calls to multiple objects and unifies all functionality under
a single interface. Might prevent some IDE type-hinting as it's dynamic
Returns
-------
StokeRunner
runtime runner object
"""
# Get the classes
dist_class = self._get_distributed_mixin()
fp16_class = self._get_fp16_mixin()
optimizer_class = self._get_optimizer_mixin()
io_class = self._get_io_mixin()
# Python MRO hack to make sure the inits of all the Mixin classes get called
def __multiple_mixin_init__(*args, **kwargs):
dist_class.__init__(*args, **kwargs)
fp16_class.__init__(*args, **kwargs)
optimizer_class.__init__(*args, **kwargs)
io_class.__init__(*args, **kwargs)
# Configs pass through
kwargs_dict = {
"amp_config": self.amp_config,
"apex_config": self.apex_config,
"ddp_config": self.ddp_config,
"deepspeed_config": self.deepspeed_config,
"horovod_config": self.horovod_config,
"oss_config": self.oss_config,
"sharded_config": self.sddp_config,
"fully_sharded_config": self.fsdp_config,
}
# Generate the runner class from the mixins based on the StokeStatus
runner_class = type(
"StokeRunner",
(dist_class, fp16_class, optimizer_class, io_class),
{"__init__": __multiple_mixin_init__},
)(
verbose=self._verbose,
batch_size_per_device=self.batch_size,
grad_accum_steps=self.grad_accum,
grad_clip=self.grad_clip,
info_rank=self._info_rank,
loss=self._loss,
**kwargs_dict,
)
# Make a list of class info for print later
class_info = [
f"Distributed Mixin: {dist_class.__name__}",
f"Optimizer Mixin: {dist_class.__name__}",
f"FP16 Mixin: {fp16_class.__name__}",
f"IO Mixin: {io_class.__name__}",
]
return runner_class, class_info
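# A minimal sketch of the dynamic type() mixin pattern used above, with hypothetical
# stand-in mixins (not Stoke's actual classes):
#
#     class _DistMixin:
#         def __init__(self, **kwargs):
#             self.dist = kwargs.get("dist")
#
#     class _FP16Mixin:
#         def __init__(self, **kwargs):
#             self.fp16 = kwargs.get("fp16")
#
#     def _init_all(self, **kwargs):
#         _DistMixin.__init__(self, **kwargs)
#         _FP16Mixin.__init__(self, **kwargs)
#
#     runner = type("Runner", (_DistMixin, _FP16Mixin), {"__init__": _init_all})(
#         dist="ddp", fp16="amp"
#     )  # the instance carries state from both mixin __init__ calls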
def _get_io_mixin(self):
"""Determines which IO class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated IO class
"""
if self.is_deepspeed:
return_class = RunnerIOEnum.deepspeed.value
elif self.is_horovod:
return_class = RunnerIOEnum.horovod.value
elif self.is_ddp:
return_class = RunnerIOEnum.ddp.value
else:
return_class = RunnerIOEnum.base.value
return return_class
def _get_optimizer_mixin(self):
"""Determines which optimizer class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated optimizer class
"""
if self.oss:
return_class = RunnerOptimizerEnum.oss.value
else:
return_class = RunnerOptimizerEnum.base.value
return return_class
def _get_distributed_mixin(self):
"""Determines which distributed class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated distributed class
"""
# if not gpu then fall to cpu single
if not self.gpu:
return_class = RunnerDistEnum.cpu.value
# if gpu but no distributed then fall to single gpu
elif self.gpu and (self.distributed is None):
return_class = RunnerDistEnum.gpu.value
elif self.gpu and (self.distributed is not None):
return_class = RunnerDistEnum[self.distributed].value
else:
raise ValueError("Stoke -- Cannot map to a valid distributed class")
return return_class
def _get_fp16_mixin(self):
"""Determines which fp16 class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated fp16 class
"""
if self.fp16 is not None:
return_class = RunnerFP16Enum[self.fp16].value
else:
return_class = RunnerFP16Enum.full.value
return return_class
def DataLoader(
self,
dataset: Dataset[T_co],
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: _collate_fn_t = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
):
"""Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.
The shim is necessary for two reasons: to inject some horovod runtime configs (making sure forkserver is used)
and to automatically handle device placement, since the gpu/fp16 flags can't be determined until the StokeStatus
object is available post-init. This could be disconnected from this class, but it would require the
user to forward on device or fp16 configs, which breaks the paradigm that the flags only need to be set and
never handled
Parameters
----------
dataset: Dataset
dataset from which to load the data.
shuffle: bool, default: False
set to ``True`` to have the data reshuffled at every epoch.
sampler: Sampler or Iterable, default: None
defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler: Sampler or Iterable, default: None
like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers: int, default: 0
how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
collate_fn: callable, default: None
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory: bool, default: False
If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last: bool, default: False
set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
will be smaller.
timeout: numeric, default: 0
if positive, the timeout value for collecting a batch from workers. Should always be non-negative.
worker_init_fn: callable, default: None
If not ``None``, this will be called on each worker subprocess with the worker id
(an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.
prefetch_factor: int, default: 2
Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers
samples prefetched across all workers.
persistent_workers: bool, default: False
If ``True``, the data loader will not shut down the worker processes after a dataset has been
consumed once. This keeps the workers' `Dataset` instances alive.
Returns
-------
StokeDataLoader
wrapped torch.utils.data.DataLoader object
"""
# Check if forkserver is available for horovod and use
if (
num_workers > 0
and hasattr(torch.multiprocessing, "_supports_context")
and torch.multiprocessing._supports_context
and "forkserver" in torch.multiprocessing.get_all_start_methods()
and self.is_horovod
):
multiprocessing_context = "forkserver"
if self._verbose and self.gpu:
print(f"Automatically handling moving model input data to GPU(s)...")
# Forward the already known options from the Stoke status
return StokeDataLoader(
gpu=self.gpu,
fp16=self.fp16,
batch_size=self.batch_size,
dataset=dataset,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
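# A minimal usage sketch for the shim above (assumes `stoke_obj` is an already
# constructed Stoke instance and `train_dataset` a torch Dataset; both names are
# illustrative):
#
#     loader = stoke_obj.DataLoader(dataset=train_dataset, num_workers=4, pin_memory=True)
#
# batch_size, device placement, and fp16 handling are forwarded automatically from
# the StokeStatus, so they are not passed here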
def model(self, *args, **kwargs):
"""Wrapped model forward call
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the model forward call
Returns
-------
model forward output
"""
with self._runner.model_context:
return self._model(*args, **kwargs)
# return self.model_access(*args, **kwargs)
def loss(self, *args, **kwargs):
"""Wrapped callable loss function call
Handles internal logic of aggregating up the losses for single and multiple losses
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the loss function call(s)
Returns
-------
outputs of callable loss function(s)
"""
# TODO: WIP Handle multiple losses. Should support list/tuple of losses. Check non base PyTorch
with self._runner.loss_context:
if isinstance(self._loss, (list, tuple)):
loss = type(self._loss)(val(*args, **kwargs) for val in self._loss)
sync_loss = [self.detach_and_sync_loss(val) for val in loss]
self._last_step_loss = type(self._loss)(
val for idx, val in enumerate(sync_loss)
)
self._agg_loss = type(self._loss)(
self._agg_loss[idx] + val for idx, val in enumerate(sync_loss)
)
self._handle_ema_loss(loss=sync_loss)
if self.grad_accum > 1 and self.model_access.training:
loss = type(loss)(val / self.grad_accum for val in loss)
else:
loss = self._loss(*args, **kwargs)
sync_loss = self.detach_and_sync_loss(loss)
self._last_step_loss = sync_loss
self._agg_loss += sync_loss
self._handle_ema_loss(loss=sync_loss)
# Handle grad accumulation by dividing by the accumulation steps
if self.grad_accum > 1 and self.model_access.training:
loss = loss / self.grad_accum
return loss
def _handle_ema_loss(self, loss: Union[float, List[float], Tuple[float]]):
"""Handles calculating the ema loss
Parameters
----------
loss: Union[float, List[float], Tuple[float]]
current calculated loss list, tuple or float
Returns
-------
None
"""
self._rolling_loss_steps += 1
if isinstance(loss, (list, tuple)):
self._rolling_mean_loss = type(self._rolling_mean_loss)(
self._ema_loss(value=val, current_mean=self._rolling_mean_loss[idx])
for idx, val in enumerate(loss)
)
else:
self._rolling_mean_loss = self._ema_loss(
value=loss, current_mean=self._rolling_mean_loss
)
def _ema_loss(self, value: float, current_mean: float):
"""Calculate the ema of the loss
Parameters
----------
value: float
current loss value
current_mean: float
current mean value
Returns
-------
current ema value: float
"""
if self._rolling_loss_steps == 1:
return value
else:
return (self._ema_weight * value) + (
(1.0 - self._ema_weight) * current_mean
)
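# A worked example of the EMA update above: with _ema_weight=0.1, current_mean=2.0,
# and a new loss value of 1.0,
#     (0.1 * 1.0) + (0.9 * 2.0) = 1.9
# i.e. the rolling mean moves 10% of the way toward the newest value; the very
# first step returns the raw value to seed the mean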
def backward(
self, loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
):
"""Wrapped backwards call
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
Callable loss function(s)
Returns
-------
None
"""
# Increment the grad counter
self._grad_accum_counter += 1
# Set the context based on the counter
dist_cm = (
nullcontext()
if self._check_accum()
else self._runner.grad_accum_context(self._model)
)
with dist_cm:
self._runner.backward_call(
loss=loss, model=self.model_access, optimizer=self._optimizer
)
# Increment the number of total calls to backward (each backward to a loss is only considered 1)
self._backward_steps += 1
def step(self):
"""Wrapped step call
Handles grad clipping internally
Returns
-------
None
"""
# Step the optimizer only if the modulo is zero
if self._check_accum():
if self._verbose and self.grad_accum > 0:
self.print(f"Gradient Accumulation Steps: {self.grad_accum}")
# Clip if needed
if self.grad_clip is not None:
self._runner.clip_grad(
self.grad_clip,
self._model if self.fully_sharded else self.model_access,
self._optimizer,
oss=self.oss,
horovod=self.is_horovod,
deepspeed=self.is_deepspeed,
fsdp=self.fully_sharded,
)
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
# Reset for the accumulated step
self._reset()
# Increment the number of step calls to the optimizer
self._optimizer_steps += 1
# if deepspeed we need to step every time as it handles the grad accumulation internally
elif self.is_deepspeed:
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
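# A minimal training-loop sketch tying together the wrapped calls above (assumes
# `stoke_obj` and `loader` were constructed as elsewhere in this module; the names
# are illustrative):
#
#     for x, y in loader:
#         out = stoke_obj.model(x)
#         loss = stoke_obj.loss(out, y)
#         stoke_obj.backward(loss)
#         stoke_obj.step()  # steps the optimizer only on the accumulation boundary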
def _reset(self):
"""Resets the state post optimizer step call
Returns
-------
None
"""
if self._verbose:
self.print("Resetting all grad/variables for next optimizer step")
# Zero the grads if not deepspeed
if not self.is_deepspeed:
self.zero_grads()
# Reset counter
self._grad_accum_counter = 0
# Reset agg loss -- single or multiple losses
self._agg_loss = self._set_loss_to_zero()
def save(
self,
path: str,
name: Optional[str] = None,
extension: str = "pt",
create_directory: bool = True,
extras: Optional[dict] = None,
):
"""Saves a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory to save the model checkpoint (prefer absolute paths over relative paths)
name: Optional[str], default: None
name used to save the checkpoint file (a random uuid4 string if not provided)
extension: str, default: 'pt'
extension used to save PyTorch model checkpoint
create_directory: bool, default: True
flag to create the directory path if it doesn't exist
extras: dict, default: None
a dictionary of any extra things to save
Returns
-------
path: str
path to directory that the model checkpoint was saved
tag: str
full tag name the model checkpoint was saved as
"""
out_path, tag = self._runner.save(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
path=path,
backward_step=self._backward_steps,
grad_accum_step=self._grad_accum_counter,
optimizer_step=self._optimizer_steps,
name=name if name is not None else str(uuid4()),
scaler_dict=self.fp16_state_dict,
extension=extension,
create_directory=create_directory,
extras=extras,
status=self.status.status,
)
self.print(f"Successfully saved model checkpoint to {out_path}/{tag}")
return out_path, tag
def load(self, path: str, tag: str, strict: bool = True):
"""Loads a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory that the model checkpoint was saved (prefer absolute paths over relative paths)
tag: str
full tag name the model checkpoint was saved as
strict: bool, default: True
whether to strictly enforce that the checkpoint keys match the model's keys (False ignores non-matching keys)
Returns
-------
extras: dict, default: None
a dictionary of any custom fields the user passed to the save function
"""
# TODO: How to deal with mapping between backends? e.g. FP16 model back to FP32? Or multi-gpu to CPU?
backward_step, grad_accum_step, optimizer_step, extras = self._runner.load(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
gpu=self.gpu,
path=path,
tag=tag,
scaler_dict_fn=self._load_fp16_state_dict_fn(),
strict=strict,
)
# Reset values based on what was in the load dict
self._backward_steps = backward_step
self._grad_accum_counter = grad_accum_step
self._optimizer_steps = optimizer_step
self.print(f"Successfully loaded model checkpoint from {path}/{tag}")
# Return the extras dict
return extras
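# A minimal checkpoint roundtrip sketch using save/load above (assumes `stoke_obj`;
# the path and name are illustrative):
#
#     out_path, tag = stoke_obj.save(path="/tmp/checkpoints", name="epoch_1")
#     extras = stoke_obj.load(path=out_path, tag=tag)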
def print_num_model_parameters(
self, normalize: ParamNormalize = ParamNormalize.MILLION
):
"""
Parameters
----------
normalize: ParamNormalize, default: ParamNormalize.MILLION
ParamNormalize choice for pretty print normalizing
Returns
-------
None
"""
self.print(
f"Total Trainable Model Parameters: "
f"{(self.num_model_parameters / normalize.value):.3f} {normalize.name}"
)
def detach_and_sync_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
device=None,
):
"""Shorthand method to detach and sync loss
Maps to the runner function of the same name
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es)
device: default: None
device to sync across
Returns
-------
loss that is synced across devices and all_reduced w/ SUM
"""
return self._runner.detach_and_sync_loss(loss=loss, device=device)
def zero_grads(self):
"""Zeros the optimizer grads depending on the optimizer type
Returns
-------
None
"""
zero_optimizer_grads(
optimizer=self._optimizer, apex=self.is_apex, horovod=self.is_horovod
)
def reset(self):
"""Public method for resetting the underlying stoke state
Returns
-------
None
"""
self._reset()
def reset_tracking(self):
"""Public method for resetting all underlying stoke tracked variables
Returns
-------
None
"""
# Create some tracking vars
self._grad_accum_counter = 0
self._optimizer_steps = 0
self._backward_steps = 0
self._last_step_loss = self._set_loss_to_zero()
self._agg_loss = self._set_loss_to_zero()
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
def dump_model_parameter_info(self):
"""Dumps all parameter information for named parameters (shape, device, dtype)
Returns
-------
None
"""
self.print("Dumping all model parameter information to stdout....")
for name, param in self.model_access.named_parameters():
if param.requires_grad:
self.print(
f"Name: {name}, Shape: {param.shape}, "
f"Device: {param.device}, dtype: {param.dtype}"
)
def _load_fp16_state_dict_fn(self):
"""Returns the function to load the sacler state dict
Returns
-------
mp_state_dict_fn: Callable, default: None
callable function to load the scaler state dict
"""
mp_state_dict_fn = None
if self.scaler is not None:
if self.is_apex:
try:
from apex import amp
mp_state_dict_fn = amp.load_state_dict
except ImportError as e:
print(
e,
": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",
)
else:
mp_state_dict_fn = self.scaler.load_state_dict
return mp_state_dict_fn
def barrier(self):
"""Calls the underlying distributed barrier if available"""
self._runner.barrier()
@property
def step_loss(self):
"""Gets the last step loss synced across device(s) (unscaled)"""
return self._last_step_loss
@property
def model_access(self):
"""Interface for model access due to the different types between the DP, DDP, and SDDP implementations"""
if isinstance(self._model, (DDP, DP, SDDP, FSDP)):
return self._model.module
else:
return self._model
@property
def loss_access(self):
"""Gets loss tensor(s)"""
return self._loss
@property
def optimizer(self):
"""Gets the optimizer"""
return self._optimizer
@property
def scaler(self):
"""Gets the current scaler object"""
return self._runner.scaler
@property
def fp16_state_dict(self):
"""Gets the fp16 state dict from various methods"""
mp_state_dict = None
if self.scaler is not None:
if self.is_apex:
try:
from apex import amp
mp_state_dict = amp.state_dict()
except ImportError as e:
print(
e,
": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",
)
elif self.is_amp:
mp_state_dict = self.scaler.state_dict()
return mp_state_dict
@property
def status(self):
"""Gets the StokeStatus object"""
return self._status
@property
def batch_size(self):
"""Shortcut to batch size"""
return self._status.batch_size
@property
def effective_batch_size(self):
"""Shortcut to effective batch size"""
return self._status.effective_batch_size
@property
def grad_clip(self):
"""Shortcut to get grad clip"""
return self._status.grad_clip
@property
def grad_accum(self):
"""Shortcut to get grad accumulation"""
return self._status.grad_accum
@property
def gpu(self):
"""Shortcut to get GPU status"""
return self._status.gpu
@property
def cuda(self):
"""Shortcut to get cuda status"""
return self._status.cuda
@property
def nccl(self):
"""Shortcut to get nccl status"""
return self._status.nccl
@property
def fp16(self):
"""Shortcut to get FP16 status"""
return self._status.fp16
@property
def is_apex(self):
"""Returns if APEX is activated"""
return self._status.is_fp16_apex
@property
def is_amp(self):
"""Returns if AMP is activated"""
return self._status.is_fp16_amp
@property
def distributed(self):
"""Shortcut to distributed status"""
return self._status.distributed
@property
def is_ddp(self):
"""Returns if DDP is activated"""
return self._status.is_distributed_ddp
@property
def is_horovod(self):
"""Returns if Horovod is activated"""
return self._status.is_distributed_horovod
@property
def is_deepspeed(self):
"""Returns if Deepspeed is acticated"""
return self._status.is_distributed_deepspeed
@property
def oss(self):
"""Returns if Fairscale optimizer state sharding status"""
return self._status.oss
@property
def sharded(self):
"""Returns if Fairscale sharded DDP status"""
return self._status.sharded
@property
def fully_sharded(self):
"""Returns if Fairscale fully sharded DDP status"""
return self._status.fully_sharded
@property
def world_size(self):
"""Shortcut to get world size"""
return self._runner.world_size
@property
def rank(self):
"""Shortcut to get rank"""
return self._runner.rank
@property
def amp_config(self):
"""Returns amp config or None based on amp state"""
return self._status.amp_config if self.is_amp else None
@property
def apex_config(self):
"""Returns apex config or None based on apex state"""
return self._status.apex_config if self.is_apex else None
@property
def ddp_config(self):
"""Returns ddp config or None based on ddp state"""
return self._status.ddp_config if self.is_ddp else None
@property
def deepspeed_config(self):
"""Returns deepspeed config or None based on deepspeed state"""
return self._status.deepspeed_config if self.is_deepspeed else None
@property
def oss_config(self):
"""Returns oss config or None based on ossstate"""
return self._status.oss_config if self.oss else None
@property
def sddp_config(self):
"""Returns sddp config or None based on sddp state"""
return self._status.sddp_config if self.sharded else None
@property
def fsdp_config(self):
"""Returns fsdp config or None based on fsdp state"""
return self._status.fsdp_config if self.fully_sharded else None
@property
def horovod_config(self):
"""Returns horovod config or None based on horovod state"""
return self._status.horovod_config if self.is_horovod else None
@property
def num_model_parameters(self):
"""Returns number of parameters that require gradients"""
return sum(p.numel() for p in self.model_access.parameters() if p.requires_grad)
@property
def ema_loss(self):
"""Returns the current rolling mean loss"""
return self._rolling_mean_loss
Classes
Stoke
class Stoke(
model: torch.nn.modules.module.Module,
optimizer: stoke.configs.StokeOptimizer,
loss: Union[Callable, List[Callable], Tuple[Callable]],
batch_size_per_device: int,
grad_accum_steps: Union[int, NoneType] = 1,
grad_clip: Union[stoke.configs.ClipGradConfig, stoke.configs.ClipGradNormConfig, NoneType] = None,
gpu: bool = False,
fp16: Union[stoke.status.FP16Options, NoneType] = None,
distributed: Union[stoke.status.DistributedOptions, NoneType] = None,
fairscale_oss: bool = False,
fairscale_sddp: bool = False,
fairscale_fsdp: bool = False,
configs: Union[List[Union[stoke.configs.AMPConfig, stoke.configs.ApexConfig, stoke.configs.DDPConfig, stoke.configs.DeepspeedConfig, stoke.configs.FairscaleOSSConfig, stoke.configs.FairscaleSDDPConfig, stoke.configs.FairscaleFSDPConfig, stoke.configs.HorovodConfig]], NoneType] = None,
info_rank: Union[int, List[int], NoneType] = 0,
verbose: bool = True,
ema_weight: float = 0.1
)
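
For orientation, a minimal construction sketch is shown below. The model, optimizer settings, and loss are illustrative stand-ins, and it assumes the package re-exports `Stoke` and `StokeOptimizer` at the top level:

```python
import torch
from stoke import Stoke, StokeOptimizer  # assumes top-level re-exports

# Illustrative stand-in model and optimizer configuration
model = torch.nn.Linear(10, 2)
opt_config = StokeOptimizer(
    optimizer=torch.optim.SGD,
    optimizer_kwargs={"lr": 1e-3},
)

stoke_obj = Stoke(
    model=model,
    optimizer=opt_config,
    loss=torch.nn.CrossEntropyLoss(),
    batch_size_per_device=32,
    grad_accum_steps=4,
    gpu=False,
)
```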
Attributes
Name | Type | Description | Default |
---|---|---|---|
amp_config | None | None | None |
apex_config | None | None | None |
batch_size | None | None | None |
cuda | None | None | None |
ddp_config | None | None | None |
deepspeed_config | None | None | None |
distributed | None | None | None |
effective_batch_size | None | None | None |
ema_loss | None | None | None |
fp16 | None | None | None |
fsdp_config | None | None | None |
fully_sharded | None | None | None |
gpu | None | None | None |
grad_accum | None | None | None |
grad_clip | None | None | None |
horovod_config | None | None | None |
is_amp | None | None | None |
is_apex | None | None | None |
is_ddp | None | None | None |
is_deepspeed | None | None | None |
is_horovod | None | None | None |
loss_access | None | None | None |
model_access | None | None | None |
nccl | None | None | None |
num_model_parameters | None | None | None |
optimizer | None | None | None |
oss | None | None | None |
oss_config | None | None | None |
rank | None | None | None |
scaler | None | None | None |
sddp_config | None | None | None |
sharded | None | None | None |
status | None | None | None |
world_size | None | None | None |
_agg_loss | Union[float, List[float], Tuple[float]] | aggregated loss for grad accumulation (single or multiple losses) | None |
_backward_steps | int | Number of times gradients have been calculated on a batch of samples (calls to backward) | None |
_grad_accum_counter | int | counter for grad accumulation steps | None |
_loss | Union[Callable, List[Callable], Tuple[Callable]] | callable function that calculates a loss from the model outputs | None |
_last_step_loss | list, tuple, or float | last loss step calculation aggregated over device(s) | None |
_model | torch.nn.Module | instance of torch.nn.Module for Stoke to handle | None |
_optimizer | StokeOptimizer | StokeOptimizer config object that describes the torch.optim.Optimizer and its kwargs | None |
_optimizer_steps | int | Number of times step has been called on the optimizer | None |
_runner | StokeRunner | the dynamically created runtime object that handles all ops | None |
_status | StokeStatus | StokeStatus object that sets and maintains the current configuration | None |
_verbose | bool | print verbosity | None |
_rolling_loss_steps | int | number of steps that have been called for the rolling loss | None |
_rolling_mean_loss | list, tuple, or float | current ema loss | None |
_ema_weight | float | weight used for any ema calculation on metrics | None |
??? example "View Source" class Stoke:
"""High level stoke object that manages all necessary configs and provides a unified interface to ops
This is the main class within Stoke. Functionally it manages all interfaces to the necessary wrapped ops (model,
loss, backward, step), provides helper functions, and dynamically constructs the runtime that handles the
combinatorics problem of underlying frameworks (DDP, Horovod, Deepspeed, Fairscale),
mixed-precision (AMP or APEX) and devices (CPU or GPU)
Attributes
----------
amp_config
apex_config
batch_size
cuda
ddp_config
deepspeed_config
distributed
effective_batch_size
ema_loss
fp16
fsdp_config
fully_sharded
gpu
grad_accum
grad_clip
horovod_config
is_amp
is_apex
is_ddp
is_deepspeed
is_horovod
loss_access
model_access
nccl
num_model_parameters
optimizer
oss
oss_config
rank
scaler
sddp_config
sharded
status
world_size
_agg_loss: Union[float, List[float], Tuple[float]]
aggregated loss for grad accumulation (single or multiple losses)
_backward_steps: int
Number of times gradients have been calculated on a batch of samples (calls to backward)
_grad_accum_counter: int
counter for grad accumulation steps
_loss: Union[Callable, List[Callable], Tuple[Callable]]
callable function that calculates a loss from the model outputs
_last_step_loss: list, tuple, or float
last loss step calculation aggregated over device(s)
_model: torch.nn.Module
instance of torch.nn.Module for Stoke to handle
_optimizer: StokeOptimizer
StokeOptimizer config object that describes the torch.optim.Optimizer and it's kwargs
_optimizer_steps: int
Number of times step has been called on the optimizer
_runner: StokeRunner
the dynamically created runtime object that handles all ops
_status: StokeStatus
StokeStatus object that sets and maintains the current configuration
_verbose: bool
print verbosity
_rolling_loss_steps: int
number of steps that have been called for the rolling loss
_rolling_mean_loss: list, tuple, or float
current ema loss
_ema_weight: float
weight used for any ema calculation on metrics
"""
def __init__(
self,
model: torch.nn.Module,
optimizer: StokeOptimizer,
loss: Union[Callable, List[Callable], Tuple[Callable]],
batch_size_per_device: int,
grad_accum_steps: Optional[int] = 1,
grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]] = None,
gpu: bool = False,
fp16: Optional[FP16Options] = None,
distributed: Optional[DistributedOptions] = None,
fairscale_oss: bool = False,
fairscale_sddp: bool = False,
fairscale_fsdp: bool = False,
configs: Optional[
List[
Union[
AMPConfig,
ApexConfig,
DDPConfig,
DeepspeedConfig,
FairscaleOSSConfig,
FairscaleSDDPConfig,
FairscaleFSDPConfig,
HorovodConfig,
]
]
] = None,
info_rank: Optional[Union[int, List[int]]] = 0,
verbose: bool = True,
ema_weight: float = 0.1,
):
"""Init for Stoke class object
Parameters
----------
model: torch.nn.Module
PyTorch model
optimizer: StokeOptimizer
Optimizer configuration
loss: Union[Callable, List[Callable], Tuple[Callable]]
Callable loss function or functions
batch_size_per_device: int
Batch size at the single device level
grad_accum_steps: Optional[int], default: 1
Number of gradient accumulation steps
grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]], default: None
Gradient clipping configuration
gpu: bool, default: False
flag to use GPU device(s)
fp16: Optional[FP16Options], default: None
Choice of mixed-precision backend
distributed: Optional[DistributedOptions], default: None
Choice of distributed backend
fairscale_oss: bool, default: False
Flag to activate optimizer state sharding using Fairscale
fairscale_sddp: bool, default: False
Flag to activate sharded DDP using Fairscale
fairscale_fsdp: bool, default: False
Flag to activate fully sharded DDP using Fairscale
configs: Optional[List[Union[AMPConfig, ApexConfig, DDPConfig, DeepspeedConfig, FairscaleOSSConfig, FairscaleSDDPConfig, FairscaleFSDPConfig, HorovodConfig]], default: None
Configuration objects for runtimes
info_rank: Optional[Union[int, List[int]]], default = 0
Constrain prints to specific devices
verbose: bool, default: True
Flag for verbosity
ema_weight: float, default: 0.5
weight used for any ema calculation on metrics
"""
# Verbosity
self._verbose = verbose
# Info rank
self._info_rank = info_rank
# EMA
self._ema_weight = ema_weight
# Setup the StokeState
self._status = StokeStatus(
batch_size_per_device=batch_size_per_device,
grad_accum=grad_accum_steps,
grad_clip=grad_clip,
gpu=gpu,
fp16=fp16,
distributed=distributed,
fairscale_oss=fairscale_oss,
fairscale_sddp=fairscale_sddp,
fairscale_fsdp=fairscale_fsdp,
configs=configs,
)
# Run some checks
self._model = self._check_model(model)
self._optimizer = self._check_optimizer(optimizer)
self._loss = self._check_loss(loss)
# Dynamically construct the StokeRunner from the StokeStatus
self._runner, class_info = self._build_runner()
# Setup distributed backend
self._runner.setup_distributed()
# Post here the runner will have the print_device function that is mapped to the self.print here
# as it needs rank to be accessible before working
if self._verbose:
dev_id = (
self.rank
if (self.rank == "cpu" or self.rank == "gpu")
else self._info_rank
)
self.print(f"Printing verbose information on rank(s): {dev_id}")
# Print the runner class info from the mixins
self.print(class_info)
# Possibly place model on GPU depending on StokeStatus -- before wrap calls
self._place_model_on_gpu()
# Handle the wrap ops in the correct order
self._handle_ordered_wrap_ops(optimizer=optimizer)
# Create some tracking vars
self._grad_accum_counter = 0
self._optimizer_steps = 0
self._backward_steps = 0
self._last_step_loss = self._set_loss_to_zero()
self._agg_loss = self._set_loss_to_zero()
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
# Set post-init status variables
self._status.set_post_init_values(world_size=self.world_size)
# Print the final configuration
if self._verbose:
self.print(msg=self._status)
def _wrap_optimizer_then_model(self, optimizer: StokeOptimizer):
"""Handles wrapping of optimizer then the model
This holds only for SDDP, Horovod, and APEX as these need to use an instantiated optimizer before wrapped
methods are called
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# Build the optimizer
self._optimizer = self._runner.build_optimizer(
optimizer=optimizer["optimizer"],
optimizer_kwargs=optimizer["optimizer_kwargs"],
model=self._model,
)
# Setup/Initialize FP16 backend -- in this case the optimizer is passed through
self._runner.wrap_fp16(model=self._model, optimizer=self._optimizer)
# Wrap with distributed backend -- in this case the optimizer is passed through
self._model, self._optimizer = self._runner.wrap_distributed(
model=self._model, grad_accum=self.grad_accum, optimizer=self._optimizer
)
def _wrap_model_then_optimizer(self, optimizer: StokeOptimizer):
"""Handles wrapping of model then optimizer
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# Wrap with distributed backend -- in this case the optimizer is passed as None since it doesn't exist yet
# don't use the return for the optimizer in this case
self._model, _ = self._runner.wrap_distributed(
model=self._model, grad_accum=self.grad_accum, optimizer=None
)
# Setup/Initialize FP16 backend -- in this case the optimizer is passed as None since it doesn't exist yet
self._runner.wrap_fp16(model=self._model, optimizer=None)
# Build the optimizer
self._optimizer = self._runner.build_optimizer(
optimizer=optimizer["optimizer"],
optimizer_kwargs=optimizer["optimizer_kwargs"],
model=self._model,
)
def _handle_ordered_wrap_ops(self, optimizer: StokeOptimizer):
"""Handles wrapping model, using FP16, and wrapping optimizer in the correct order depending on Stoke Status
Parameters
----------
optimizer: StokeOptimizer
Optimizer configuration
Returns
-------
None
"""
# if SDDP + OSS, Horovod, and APEX then we need to make sure that the optimizer gets wrapped before the model
# gets wrapped, all other models follow standard DDP paradigm (or their own DeepSpeed)
if (self.sharded and self.oss) or self.is_apex or self.is_horovod:
self._wrap_optimizer_then_model(optimizer=optimizer)
else:
self._wrap_model_then_optimizer(optimizer=optimizer)
def _check_accum(self):
"""Checks if the current step is the last accumulation step
Returns
-------
bool
"""
return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == 0
def _check_pre_accum(self):
"""Checks if we are at the pre-accumulate step
Returns
-------
bool
"""
return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == self.grad_accum
def _set_loss_to_zero(self):
"""Used to set a loss tracker to zero depending on the type
Returns
-------
float or list or tuple of reset loss
"""
return (
type(self._loss)([0.0] * len(self._loss))
if isinstance(self._loss, (list, tuple))
else 0.0
)
def reset_ema(self):
"""Used to reset the current state of the rolling mean loss
Returns
-------
None
"""
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
def print_ema_loss(
self, prepend_msg: str = "Current EMA Loss", single_line: bool = False
):
"""Prints the current ema loss synced across all devices
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Current EMA Loss"
message prepend to print
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
if isinstance(self._rolling_mean_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val:.3f}"
for idx, val in enumerate(self._rolling_mean_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._rolling_mean_loss:.3f}")
def print_mean_accumulated_synced_loss(
self,
prepend_msg: str = "Mean Accumulated & Synced Loss",
pre_backwards: bool = True,
single_line: bool = False,
):
"""Prints the mean accumulated and device synced loss only after the grad accumulation step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Mean Accumulated & Synced Loss"
message prepend to print
pre_backwards: bool, default: True
if being called pre backward step
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
check_fn = self._check_pre_accum if pre_backwards else self._check_accum
if check_fn():
if isinstance(self._agg_loss, (list, tuple)):
print_vals = self._scale_agg_loss()
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._scale_agg_loss():.3f}")
def _scale_agg_loss(self):
"""Scales the mean aggregated loss by grad accum
Returns
-------
scale_vals: list or float of mean aggregated loss
"""
if isinstance(self._agg_loss, (list, tuple)):
scale_vals = [
val / self.grad_accum for idx, val in enumerate(self._agg_loss)
]
else:
scale_vals = self._agg_loss / self.grad_accum
return scale_vals
def print_synced_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
prepend_msg: str = "Step Synced Loss",
device=None,
single_line: bool = False,
):
"""Prints a device synced loss at a single step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es) on the device
prepend_msg: str, default: "Step Synced Loss"
message prepend to print
device: default: None
specify the device to place the synced loss on (defaults to same device)
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
printable_loss = self.detach_and_sync_loss(loss, device)
if isinstance(printable_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val * self.grad_accum:.3f}"
for idx, val in enumerate(printable_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(msg=f"{prepend_msg}: {printable_loss * self.grad_accum:.3f}")
def print_on_devices(
self, msg: Union[str, List[str]], rank: Optional[Union[int, List[int]]] = 0
):
"""Wraps runner print interface for shorter semantics
Parameters
----------
msg: str
message to print
rank: Union[int, List[int]], default: 0
which ranks to print on
Returns
-------
None
"""
self._runner.print_device(msg=msg, rank=rank)
def print(self, msg: Union[str, List[str]], single_line: bool = False):
"""Wraps the runners print device and forces print on the _info_rank attribute(s)
Parameters
----------
msg: str
message to print
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
self._runner.print_device(
msg=msg, rank=self._info_rank, single_line=single_line
)
@staticmethod
def _check_model(model: torch.nn.Module):
"""Verifies the type of the model
Parameters
----------
model: torch.nn.Module
current torch model
Returns
-------
None
"""
# Check if the model is an nn.Module such that it has a forward method
if not isinstance(model, torch.nn.Module):
raise TypeError(
f"Stoke -- Model is not of type torch.nn.Module, currently {type(model)}"
)
return model
@staticmethod
def _check_optimizer(optimizer: StokeOptimizer):
"""Verifies the type of the optimizer
Parameters
----------
optimizer: StokeOptimizer
Current optimizer configuration TypedDict (aka dict)
Returns
-------
None
"""
if not isinstance(optimizer, dict):
raise TypeError(
f"Stoke -- Optimizer is not of type torch.optim.Optimizer, currently {type(optimizer)}"
)
return optimizer
def _check_loss(self, loss: Union[Callable, List[Callable], Tuple[Callable]]):
"""Checks to make sure the loss function(s) is/are callable
Parameters
----------
loss: Union[Callable, List[Callable], Tuple[Callable]]
Current callable loss(es)
Returns
-------
None
"""
if isinstance(loss, (list, tuple)):
loss = [self._check_loss(val) for val in loss]
return loss
elif isinstance(loss, Callable):
return loss
else:
raise TypeError(
f"Stoke -- Loss is not of type Callable, currently {type(loss)}"
)
def _place_model_on_gpu(self):
"""Automatically moves the model to GPU device(s)
Returns
-------
None
"""
if self.gpu and not self.is_deepspeed:
if self._verbose:
self.print(f"Automatically handling moving model to GPU(s)...")
self._model.cuda()
def _build_runner(self):
"""Builds the runtime object from the mixin style classes
Mixes the distributed class, fp16 class, and optimizer class into a single object such that all can be called
from the same interface. Prevents verbose calls to multiple objects and unifies all functionality under a
a single interface. Might prevent some IDE type-hinting as it's dynamic
Returns
-------
StokeRunner
runtime runner object
"""
# Get the classes
dist_class = self._get_distributed_mixin()
fp16_class = self._get_fp16_mixin()
optimizer_class = self._get_optimizer_mixin()
io_class = self._get_io_mixin()
# Python MRO hack to make sure the inits of all the Mixin classes get called
def __multiple_mixin_init__(*args, **kwargs):
dist_class.__init__(*args, **kwargs)
fp16_class.__init__(*args, **kwargs)
optimizer_class.__init__(*args, **kwargs)
io_class.__init__(*args, **kwargs)
# Configs pass through
kwargs_dict = {
"amp_config": self.amp_config,
"apex_config": self.apex_config,
"ddp_config": self.ddp_config,
"deepspeed_config": self.deepspeed_config,
"horovod_config": self.horovod_config,
"oss_config": self.oss_config,
"sharded_config": self.sddp_config,
"fully_sharded_config": self.fsdp_config,
}
# Generate the runner class from the mixins based on the StokeStatus
runner_class = type(
"StokeRunner",
(dist_class, fp16_class, optimizer_class, io_class),
{"__init__": __multiple_mixin_init__},
)(
verbose=self._verbose,
batch_size_per_device=self.batch_size,
grad_accum_steps=self.grad_accum,
grad_clip=self.grad_clip,
info_rank=self._info_rank,
loss=self._loss,
**kwargs_dict,
)
# Make a list of class info for print later
class_info = [
f"Distributed Mixin: {dist_class.__name__}",
f"Optimizer Mixin: {dist_class.__name__}",
f"FP16 Mixin: {fp16_class.__name__}",
f"IO Mixin: {io_class.__name__}",
]
return runner_class, class_info
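# Illustrative aside (comments added for documentation, not in the original source):
# for one concrete combination of flags the dynamic type() call above is roughly
# equivalent to writing
#     class StokeRunner(DDPDistClass, AMPClass, OSSOptimizerClass, DDPIOClass):
#         def __init__(self, *args, **kwargs): ...
# except that the four mixin bases are selected at runtime from the StokeStatus flags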
def _get_io_mixin(self):
"""Determines which IO class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated IO class
"""
if self.is_deepspeed:
return_class = RunnerIOEnum.deepspeed.value
elif self.is_horovod:
return_class = RunnerIOEnum.horovod.value
elif self.is_ddp:
return_class = RunnerIOEnum.ddp.value
else:
return_class = RunnerIOEnum.base.value
return return_class
def _get_optimizer_mixin(self):
"""Determines which optimizer class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated optimizer class
"""
if self.oss:
return_class = RunnerOptimizerEnum.oss.value
else:
return_class = RunnerOptimizerEnum.base.value
return return_class
def _get_distributed_mixin(self):
"""Determines which distributed class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated distributed class
"""
# if not gpu then fall to cpu single
if not self.gpu:
return_class = RunnerDistEnum.cpu.value
# if gpu but no distributed then fall to single gpu
elif self.gpu and (self.distributed is None):
return_class = RunnerDistEnum.gpu.value
elif self.gpu and (self.distributed is not None):
return_class = RunnerDistEnum[self.distributed].value
else:
raise ValueError("Stoke -- Cannot map to a valid distributed class")
return return_class
def _get_fp16_mixin(self):
"""Determines which fp16 class to use
Embedded logic based on the enum class
Returns
-------
ABCMeta
un-instantiated fp16 class
"""
if self.fp16 is not None:
return_class = RunnerFP16Enum[self.fp16].value
else:
return_class = RunnerFP16Enum.full.value
return return_class
def DataLoader(
self,
dataset: Dataset[T_co],
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: _collate_fn_t = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
):
"""Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.
Shim is necessary for two reasons: to inject some Horovod runtime configs (making sure forkserver is used)
and to automatically handle device placement, since the gpu/fp16 flags can't be determined until the StokeStatus
object is available (which is post init). This could be disconnected from this class, but it would require the
user to forward on device or fp16 configs, which breaks the paradigm that the flags only need to be set and
never handled
Parameters
----------
dataset: Dataset
dataset from which to load the data.
shuffle: bool, default: False
set to ``True`` to have the data reshuffled at every epoch.
sampler: Sampler or Iterable, default: None
defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler: Sampler or Iterable, default: None:
like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers: int, default: 0
how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
collate_fn: callable, optional:
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory: bool, default: False:
If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last: bool, default: False
set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
will be smaller.
timeout: numeric, default: 0
if positive, the timeout value for collecting a batch from workers. Should always be non-negative.
worker_init_fn: callable, default: None
If not ``None``, this will be called on each worker subprocess with the worker id
(an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.
prefetch_factor: int, default: 2
Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers
samples prefetched across all workers.
persistent_workers: bool, default: False
If ``True``, the data loader will not shut down the worker processes after a dataset has been
consumed once. This keeps the worker `Dataset` instances alive.
Returns
-------
StokeDataLoader
wrapped torch.utils.data.DataLoader object
"""
# Check if forkserver is available for horovod and use
if (
num_workers > 0
and hasattr(torch.multiprocessing, "_supports_context")
and torch.multiprocessing._supports_context
and "forkserver" in torch.multiprocessing.get_all_start_methods()
and self.is_horovod
):
multiprocessing_context = "forkserver"
if self._verbose and self.gpu:
print(f"Automatically handling moving model input data to GPU(s)...")
# Forward the already known options from the Stoke status
return StokeDataLoader(
gpu=self.gpu,
fp16=self.fp16,
batch_size=self.batch_size,
dataset=dataset,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
def model(self, *args, **kwargs):
"""Wrapped model forward call
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the model forward call
Returns
-------
model forward output
"""
with self._runner.model_context:
return self._model(*args, **kwargs)
# return self.model_access(*args, **kwargs)
def loss(self, *args, **kwargs):
"""Wrapped callable loss function call
Handles internal logic of aggregating up the losses for single and multiple losses
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the loss function call(s)
Returns
-------
outputs of callable loss function(s)
"""
# TODO: WIP Handle multiple losses. Should support list/tuple of losses. Check non base PyTorch
with self._runner.loss_context:
if isinstance(self._loss, (list, tuple)):
loss = type(self._loss)(val(*args, **kwargs) for val in self._loss)
sync_loss = [self.detach_and_sync_loss(val) for val in loss]
self._last_step_loss = type(self._loss)(
val for idx, val in enumerate(sync_loss)
)
self._agg_loss = type(self._loss)(
self._agg_loss[idx] + val for idx, val in enumerate(sync_loss)
)
self._handle_ema_loss(loss=sync_loss)
if self.grad_accum > 1 and self.model_access.training:
loss = type(loss)(val / self.grad_accum for val in loss)
else:
loss = self._loss(*args, **kwargs)
sync_loss = self.detach_and_sync_loss(loss)
self._last_step_loss = sync_loss
self._agg_loss += sync_loss
self._handle_ema_loss(loss=sync_loss)
# Handle grad accumulation by dividing by the accumulation steps
if self.grad_accum > 1 and self.model_access.training:
loss = loss / self.grad_accum
return loss
def _handle_ema_loss(self, loss: Union[float, List[float], Tuple[float]]):
"""Handles calculating the ema loss
Parameters
----------
loss: Union[float, List[float], Tuple[float]]
current calculated loss list, tuple or float
Returns
-------
None
"""
self._rolling_loss_steps += 1
if isinstance(loss, (list, tuple)):
self._rolling_mean_loss = type(self._rolling_mean_loss)(
self._ema_loss(value=val, current_mean=self._rolling_mean_loss[idx])
for idx, val in enumerate(loss)
)
else:
self._rolling_mean_loss = self._ema_loss(
value=loss, current_mean=self._rolling_mean_loss
)
def _ema_loss(self, value: float, current_mean: float):
"""Calculate the ema of the loss
Parameters
----------
value: float
current loss value
current_mean: float
current mean value
Returns
-------
current ema value: float
"""
if self._rolling_loss_steps == 1:
return value
else:
return (self._ema_weight * value) + (
(1.0 - self._ema_weight) * current_mean
)
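# Worked example (illustrative, not in the original source): with _ema_weight = 0.1,
# a current mean of 2.0, and a new loss value of 1.0, the update yields
# (0.1 * 1.0) + (0.9 * 2.0) = 1.9; the very first step simply seeds the mean
# with the raw value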
def backward(
self, loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
):
"""Wrapped backwards call
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
computed loss tensor(s) to backpropagate
Returns
-------
None
"""
# Increment the grad counter
self._grad_accum_counter += 1
# Set the context based on the counter
dist_cm = (
nullcontext()
if self._check_accum()
else self._runner.grad_accum_context(self._model)
)
with dist_cm:
self._runner.backward_call(
loss=loss, model=self.model_access, optimizer=self._optimizer
)
# Increment the number of total calls to backward (each backward to a loss is only considered 1)
self._backward_steps += 1
def step(self):
"""Wrapped step call
Handles grad clipping internally
Returns
-------
None
"""
# Step the optimizer only if the modulo is zero
if self._check_accum():
if self._verbose and self.grad_accum > 0:
self.print(f"Gradient Accumulation Steps: {self.grad_accum}")
# Clip if needed
if self.grad_clip is not None:
self._runner.clip_grad(
self.grad_clip,
self._model if self.fully_sharded else self.model_access,
self._optimizer,
oss=self.oss,
horovod=self.is_horovod,
deepspeed=self.is_deepspeed,
fsdp=self.fully_sharded,
)
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
# Reset for the accumulated step
self._reset()
# Increment the number of step calls to the optimizer
self._optimizer_steps += 1
# if deepspeed we need to step every time as it handles the grad accumulation internally
elif self.is_deepspeed:
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
def _reset(self):
"""Resets the state post optimizer step call
Returns
-------
None
"""
if self._verbose:
self.print("Resetting all grad/variables for next optimizer step")
# Zero the grads if not deepspeed
if not self.is_deepspeed:
self.zero_grads()
# Reset counter
self._grad_accum_counter = 0
# Reset agg loss -- single or multiple losses
self._agg_loss = self._set_loss_to_zero()
def save(
self,
path: str,
name: str = uuid4(),
extension: str = "pt",
create_directory: bool = True,
extras: Optional[dict] = None,
):
"""Saves a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory to save the model checkpoint (prefer absolute paths over relative paths)
name: str, default: uuid4()
name used to save checkpoint file
extension: str, default: 'pt'
extension used to save PyTorch model checkpoint
create_directory: bool, default: True
flag to create the directory path if it doesn't exist
extras: dict, default: None
a dictionary of any extra things to save
Returns
-------
path: str
path to directory that the model checkpoint was saved
tag: str
full tag name the model checkpoint was saved as
"""
out_path, tag = self._runner.save(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
path=path,
backward_step=self._backward_steps,
grad_accum_step=self._grad_accum_counter,
optimizer_step=self._optimizer_steps,
name=name,
scaler_dict=self.fp16_state_dict,
extension=extension,
create_directory=create_directory,
extras=extras,
status=self.status.status,
)
self.print(f"Successfully saved model checkpoint to {out_path}/{tag}")
return out_path, tag
def load(self, path: str, tag: str, strict: bool = True):
"""Loads a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory that the model checkpoint was saved (prefer absolute paths over relative paths)
tag: str
full tag name the model checkpoint was saved as
strict: bool, default: True
whether to strictly enforce that the checkpoint keys match the model's state dict keys
Returns
-------
extras: dict, default: None
a dictionary of any custom fields the user passed to the save function
"""
# TODO: How to deal with mapping between backends? e.g. FP16 model back to FP32? Or multi-gpu to CPU?
backward_step, grad_accum_step, optimizer_step, extras = self._runner.load(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
gpu=self.gpu,
path=path,
tag=tag,
scaler_dict_fn=self._load_fp16_state_dict_fn(),
strict=strict,
)
# Reset values based on what was in the load dict
self._backward_steps = backward_step
self._grad_accum_counter = grad_accum_step
self._optimizer_steps = optimizer_step
self.print(f"Successfully loaded model checkpoint from {path}/{tag}")
# Return the extras dict
return extras
def print_num_model_parameters(
self, normalize: ParamNormalize = ParamNormalize.MILLION
):
"""
Parameters
----------
normalize: ParamNormalize, default: ParamNormalize.MILLION
ParamNormalize choice for pretty print normalizing
Returns
-------
None
"""
self.print(
f"Total Trainable Model Parameters: "
f"{(self.num_model_parameters / normalize.value):.3f} {normalize.name}"
)
def detach_and_sync_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
device=None,
):
"""Shorthand method to detach and sync loss
Maps to the runner function of the same name
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es)
device: default: None
device to sync across
Returns
-------
loss that is synced across devices and all_reduced w/ SUM
"""
return self._runner.detach_and_sync_loss(loss=loss, device=device)
def zero_grads(self):
"""Zeros the optimizer grads depending on the optimizer type
Returns
-------
None
"""
zero_optimizer_grads(
optimizer=self._optimizer, apex=self.is_apex, horovod=self.is_horovod
)
def reset(self):
"""Public method for resetting the underlying stoke state
Returns
-------
None
"""
self._reset()
def reset_tracking(self):
"""Public method for resetting all underlying stoke tracked variables
Returns
-------
None
"""
# Create some tracking vars
self._grad_accum_counter = 0
self._optimizer_steps = 0
self._backward_steps = 0
self._last_step_loss = self._set_loss_to_zero()
self._agg_loss = self._set_loss_to_zero()
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
def dump_model_parameter_info(self):
"""Dumps all parameter information for named parameters (shape, device, dtype)
Returns
-------
None
"""
self.print("Dumping all model parameter information to stdout....")
for name, param in self.model_access.named_parameters():
if param.requires_grad:
self.print(
f"Name: {name}, Shape: {param.shape}, "
f"Device: {param.device}, dtype: {param.dtype}"
)
def _load_fp16_state_dict_fn(self):
"""Returns the function to load the sacler state dict
Returns
-------
mp_state_dict_fn: Callable, default: None
callable function to load the scaler state dict
"""
mp_state_dict_fn = None
if self.scaler is not None:
if self.is_apex:
try:
from apex import amp
mp_state_dict_fn = amp.load_state_dict
except ImportError as e:
print(
e,
": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",
)
else:
mp_state_dict_fn = self.scaler.load_state_dict
return mp_state_dict_fn
def barrier(self):
"""Calls the underlying distributed barrier if available"""
self._runner.barrier()
@property
def step_loss(self):
"""Gets the last step loss synced across device(s) (unscaled)"""
return self._last_step_loss
@property
def model_access(self):
"""Interface for model access due to the different types between the DP, DDP, and SDDP implementations"""
if isinstance(self._model, (DDP, DP, SDDP, FSDP)):
return self._model.module
else:
return self._model
@property
def loss_access(self):
"""Gets loss tensor(s)"""
return self._loss
@property
def optimizer(self):
"""Gets the optimizer"""
return self._optimizer
@property
def scaler(self):
"""Gets the current scaler object"""
return self._runner.scaler
@property
def fp16_state_dict(self):
"""Gets the fp16 state dict from various methods"""
mp_state_dict = None
if self.scaler is not None:
if self.is_apex:
try:
from apex import amp
mp_state_dict = amp.state_dict()
except ImportError as e:
print(
e,
": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",
)
elif self.is_amp:
mp_state_dict = self.scaler.state_dict()
return mp_state_dict
@property
def status(self):
"""Gets the StokeStatus object"""
return self._status
@property
def batch_size(self):
"""Shortcut to batch size"""
return self._status.batch_size
@property
def effective_batch_size(self):
"""Shortcut to effective batch size"""
return self._status.effective_batch_size
@property
def grad_clip(self):
"""Shortcut to get grad clip"""
return self._status.grad_clip
@property
def grad_accum(self):
"""Shortcut to get grad accumulation"""
return self._status.grad_accum
@property
def gpu(self):
"""Shortcut to get GPU status"""
return self._status.gpu
@property
def cuda(self):
"""Shortcut to get cuda status"""
return self._status.cuda
@property
def nccl(self):
"""Shortcut to get nccl status"""
return self._status.nccl
@property
def fp16(self):
"""Shortcut to get FP16 status"""
return self._status.fp16
@property
def is_apex(self):
"""Returns if APEX is activated"""
return self._status.is_fp16_apex
@property
def is_amp(self):
"""Returns if AMP is activated"""
return self._status.is_fp16_amp
@property
def distributed(self):
"""Shortcut to distributed status"""
return self._status.distributed
@property
def is_ddp(self):
"""Returns if DDP is activated"""
return self._status.is_distributed_ddp
@property
def is_horovod(self):
"""Returns if Horovod is activated"""
return self._status.is_distributed_horovod
@property
def is_deepspeed(self):
"""Returns if Deepspeed is acticated"""
return self._status.is_distributed_deepspeed
@property
def oss(self):
"""Returns if Fairscale optimizer state sharding status"""
return self._status.oss
@property
def sharded(self):
"""Returns if Fairscale sharded DDP status"""
return self._status.sharded
@property
def fully_sharded(self):
"""Returns if Fairscale fully sharded DDP status"""
return self._status.fully_sharded
@property
def world_size(self):
"""Shortcut to get world size"""
return self._runner.world_size
@property
def rank(self):
"""Shortcut to get rank"""
return self._runner.rank
@property
def amp_config(self):
"""Returns amp config or None based on amp state"""
return self._status.amp_config if self.is_amp else None
@property
def apex_config(self):
"""Returns apex config or None based on apex state"""
return self._status.apex_config if self.is_apex else None
@property
def ddp_config(self):
"""Returns ddp config or None based on ddp state"""
return self._status.ddp_config if self.is_ddp else None
@property
def deepspeed_config(self):
"""Returns deepspeed config or None based on deepspeed state"""
return self._status.deepspeed_config if self.is_deepspeed else None
@property
def oss_config(self):
"""Returns oss config or None based on ossstate"""
return self._status.oss_config if self.oss else None
@property
def sddp_config(self):
"""Returns sddp config or None based on sddp state"""
return self._status.sddp_config if self.sharded else None
@property
def fsdp_config(self):
"""Returns fsdp config or None based on fsdp state"""
return self._status.fsdp_config if self.fully_sharded else None
@property
def horovod_config(self):
"""Returns horovod config or None based on horovod state"""
return self._status.horovod_config if self.is_horovod else None
@property
def num_model_parameters(self):
"""Returns number of parameters that require gradients"""
return sum(p.numel() for p in self.model_access.parameters() if p.requires_grad)
@property
def ema_loss(self):
"""Returns the current rolling mean loss"""
return self._rolling_mean_loss
Instance variables
amp_config
Returns amp config or None based on amp state
apex_config
Returns apex config or None based on apex state
batch_size
Shortcut to batch size
cuda
Shortcut to get cuda status
ddp_config
Returns ddp config or None based on ddp state
deepspeed_config
Returns deepspeed config or None based on deepspeed state
distributed
Shortcut to distributed status
effective_batch_size
Shortcut to effective batch size
ema_loss
Returns the current rolling mean loss
fp16
Shortcut to get FP16 status
fp16_state_dict
Gets the fp16 state dict from various methods
fsdp_config
Returns fsdp config or None based on fsdp state
fully_sharded
Returns the Fairscale fully sharded DDP status
gpu
Shortcut to get GPU status
grad_accum
Shortcut to get grad accumulation
grad_clip
Shortcut to get grad clip
horovod_config
Returns horovod config or None based on horovod state
is_amp
Returns if AMP is activated
is_apex
Returns if APEX is activated
is_ddp
Returns if DDP is activated
is_deepspeed
Returns if Deepspeed is activated
is_horovod
Returns if Horovod is activated
loss_access
Gets loss tensor(s)
model_access
Interface for model access due to the different types between the DP, DDP, SDDP, and FSDP implementations
nccl
Shortcut to get nccl status
num_model_parameters
Returns number of parameters that require gradients
optimizer
Gets the optimizer
oss
Returns the Fairscale optimizer state sharding (OSS) status
oss_config
Returns oss config or None based on oss state
rank
Shortcut to get rank
scaler
Gets the current scaler object
sddp_config
Returns sddp config or None based on sddp state
sharded
Returns the Fairscale sharded DDP status
status
Gets the StokeStatus object
step_loss
Gets the last step loss synced across device(s) (unscaled)
world_size
Shortcut to get world size
Methods
DataLoader
def DataLoader(
self,
dataset: torch.utils.data.dataset.Dataset[+T_co],
shuffle: bool = False,
sampler: Union[torch.utils.data.sampler.Sampler[int], NoneType] = None,
batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence[int]], NoneType] = None,
num_workers: int = 0,
collate_fn: Callable[[List[~T]], Any] = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: Callable[[int], NoneType] = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False
)
Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.
Shim is necessary for two reasons: to inject some Horovod runtime configs (making sure forkserver is used) and to automatically handle device placement, since the gpu/fp16 flags can't be determined until the StokeStatus object is available (which is post init). This could be disconnected from this class, but it would require the user to forward on device or fp16 configs, which breaks the paradigm that the flags only need to be set and never handled.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | dataset from which to load the data. | None |
shuffle | bool, default: False | set to `True` to have the data reshuffled at every epoch. | None |
sampler | Sampler or Iterable, default: None | defines the strategy to draw samples from the dataset. Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified. | None |
batch_sampler | Sampler or Iterable, default: None | like `sampler`, but returns a batch of indices at a time. Mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`. | None |
num_workers | int, default: 0 | how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process. | None |
collate_fn | callable, optional | merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. | None |
pin_memory | bool, default: False | if `True`, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your `collate_fn` returns a batch that is a custom type, see the example below. | None |
drop_last | bool, default: False | set to `True` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `False` and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. | None |
timeout | numeric, default: 0 | if positive, the timeout value for collecting a batch from workers. Should always be non-negative. | None |
worker_init_fn | callable, default: None | if not `None`, this will be called on each worker subprocess with the worker id (an int in `[0, num_workers - 1]`) as input, after seeding and before data loading. | None |
prefetch_factor | int, default: 2 | number of samples loaded in advance by each worker. `2` means there will be a total of 2 * num_workers samples prefetched across all workers. | None |
persistent_workers | bool, default: False | if `True`, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the worker `Dataset` instances alive. | None |
Returns:
Type | Description |
---|---|
StokeDataLoader | wrapped torch.utils.data.DataLoader object |
??? example "View Source" def DataLoader(
self,
dataset: Dataset[T_co],
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: _collate_fn_t = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
):
"""Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.
Shim is necessary for two reasons: to inject some Horovod runtime configs (making sure forkserver is used)
and to automatically handle device placement, since the gpu/fp16 flags can't be determined until the StokeStatus
object is available (which is post init). This could be disconnected from this class, but it would require the
user to forward on device or fp16 configs, which breaks the paradigm that the flags only need to be set and
never handled
Parameters
----------
dataset: Dataset
dataset from which to load the data.
shuffle: bool, default: False
set to ``True`` to have the data reshuffled at every epoch.
sampler: Sampler or Iterable, default: None
defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler: Sampler or Iterable, default: None:
like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers: int, default: 0
how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
collate_fn: callable, optional:
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory: bool, default: False:
If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last: bool, default: False
set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
will be smaller.
timeout: numeric, default: 0
if positive, the timeout value for collecting a batch from workers. Should always be non-negative.
worker_init_fn: callable, default: None
If not ``None``, this will be called on each worker subprocess with the worker id
(an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.
prefetch_factor: int, default: 2
Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers
samples prefetched across all workers.
persistent_workers: bool, default: False
If ``True``, the data loader will not shut down the worker processes after a dataset has been
consumed once. This keeps the worker `Dataset` instances alive.
Returns
-------
StokeDataLoader
wrapped torch.utils.data.DataLoader object
"""
# Check if forkserver is available for horovod and use
if (
num_workers > 0
and hasattr(torch.multiprocessing, "_supports_context")
and torch.multiprocessing._supports_context
and "forkserver" in torch.multiprocessing.get_all_start_methods()
and self.is_horovod
):
multiprocessing_context = "forkserver"
if self._verbose and self.gpu:
print(f"Automatically handling moving model input data to GPU(s)...")
# Forward the already known options from the Stoke status
return StokeDataLoader(
gpu=self.gpu,
fp16=self.fp16,
batch_size=self.batch_size,
dataset=dataset,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
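A minimal usage sketch (hedged: `stoke_obj` stands in for an already-constructed `Stoke` instance and `RangeDataset` is a hypothetical map-style dataset, neither of which is part of this API reference):

```python
import torch
from torch.utils.data import Dataset


class RangeDataset(Dataset):
    """Hypothetical map-style dataset returning single-float tensors."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


# batch size, device placement, and any Horovod forkserver handling are
# injected automatically from the StokeStatus held by stoke_obj
loader = stoke_obj.DataLoader(
    dataset=RangeDataset(100),
    num_workers=2,
    pin_memory=True,
)

for batch in loader:
    ...  # when gpu=True, batches are already moved to the correct device
```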
backward
def backward(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
)
Wrapped backwards call
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss | Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] | computed loss tensor(s) to backpropagate | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def backward(
self, loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
):
"""Wrapped backwards call
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
computed loss tensor(s) to backpropagate
Returns
-------
None
"""
# Increment the grad counter
self._grad_accum_counter += 1
# Set the context based on the counter
dist_cm = (
nullcontext()
if self._check_accum()
else self._runner.grad_accum_context(self._model)
)
with dist_cm:
self._runner.backward_call(
loss=loss, model=self.model_access, optimizer=self._optimizer
)
# Increment the number of total calls to backward (each backward to a loss is only considered 1)
self._backward_steps += 1
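Note that the `dist_cm` context above defers gradient synchronization while accumulating: on non-boundary calls the backward runs under the runner's `grad_accum_context` (hedged: likely a no_sync-style context on the DDP-like backends). A sketch of the resulting call pattern with `grad_accum = 4` (`stoke_obj` and `loader` are assumed names):

```python
# only every 4th backward() call lands on the accumulation boundary; the three
# in between run under the runner's grad-accum (deferred-sync) context
for inputs, targets in loader:
    loss = stoke_obj.loss(stoke_obj.model(inputs), targets)
    stoke_obj.backward(loss)  # gradient sync deferred until the counter rolls over
    stoke_obj.step()          # steps/clips/resets only on the boundary
                              # (Deepspeed instead steps internally on every call)
```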
barrier
def barrier(
self
)
Calls the underlying distributed barrier if available
??? example "View Source" def barrier(self):
"""Calls the underlying distributed barrier if available"""
self._runner.barrier()
detach_and_sync_loss
def detach_and_sync_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
device=None
)
Shorthand method to detach and sync loss
Maps to the runner function of the same name
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss | Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] | current loss(es) | None |
device | default: None | device to sync across | None |
Returns:
Type | Description |
---|---|
loss that is synced across devices and all_reduced w/ SUM | None |
??? example "View Source" def detach_and_sync_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
device=None,
):
"""Shorthand method to detach and sync loss
Maps to the runner function of the same name
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es)
device: default: None
device to sync across
Returns
-------
loss that is synced across devices and all_reduced w/ SUM
"""
return self._runner.detach_and_sync_loss(loss=loss, device=device)
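A short hedged sketch of the shorthand (`stoke_obj`, `outputs`, and `targets` are assumed names from a training step):

```python
# `outputs` and `targets` are assumed to come from a training step
loss = stoke_obj.loss(outputs, targets)

# detached copy of the loss, all-reduced (SUM) across devices
synced = stoke_obj.detach_and_sync_loss(loss=loss)
print(float(synced))  # safe to log -- no longer part of the autograd graph
```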
dump_model_parameter_info
def dump_model_parameter_info(
self
)
Dumps all parameter information for named parameters (shape, device, dtype)
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def dump_model_parameter_info(self):
"""Dumps all parameter information for named parameters (shape, device, dtype)
Returns
-------
None
"""
self.print("Dumping all model parameter information to stdout....")
for name, param in self.model_access.named_parameters():
if param.requires_grad:
self.print(
f"Name: {name}, Shape: {param.shape}, "
f"Device: {param.device}, dtype: {param.dtype}"
)
load
def load(
self,
path: str,
tag: str,
strict: bool = True
)
Loads a model checkpoint using the correct backend interface
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | path to directory that the model checkpoint was saved (prefer absolute paths over relative paths) | None |
tag | str | full tag name the model checkpoint was saved as | None |
strict | bool, default: True | whether to strictly enforce that the checkpoint keys match the model's state dict keys | None |
Returns:
Type | Description |
---|---|
dict, default: None | a dictionary of any custom fields the user passed to the save function |
??? example "View Source" def load(self, path: str, tag: str, strict: bool = True):
"""Loads a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory that the model checkpoint was saved (prefer absolute paths over relative paths)
tag: str
full tag name the model checkpoint was saved as
strict: bool, default: True
whether to strictly enforce that the checkpoint keys match the model's state dict keys
Returns
-------
extras: dict, default: None
a dictionary of any custom fields the user passed to the save function
"""
# TODO: How to deal with mapping between backends? e.g. FP16 model back to FP32? Or multi-gpu to CPU?
backward_step, grad_accum_step, optimizer_step, extras = self._runner.load(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
gpu=self.gpu,
path=path,
tag=tag,
scaler_dict_fn=self._load_fp16_state_dict_fn(),
strict=strict,
)
# Reset values based on what was in the load dict
self._backward_steps = backward_step
self._grad_accum_counter = grad_accum_step
self._optimizer_steps = optimizer_step
self.print(f"Successfully loaded model checkpoint from {path}/{tag}")
# Return the extras dict
return extras
loss
def loss(
self,
*args,
**kwargs
)
Wrapped callable loss function call
Handles internal logic of aggregating up the losses for single and multiple losses
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | list or tuple | Additional arguments should be passed as keyword arguments | None |
**kwargs | dict | Extra arguments passed to the loss function call(s) | None |
Returns:
Type | Description |
---|---|
outputs of callable loss function(s) | None |
??? example "View Source" def loss(self, args, *kwargs):
"""Wrapped callable loss function call
Handles internal logic of aggregating up the losses for single and multiple losses
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the loss function call(s)
Returns
-------
outputs of callable loss function(s)
"""
# TODO: WIP Handle multiple losses. Should support list/tuple of losses. Check non base PyTorch
with self._runner.loss_context:
if isinstance(self._loss, (list, tuple)):
loss = type(self._loss)(val(*args, **kwargs) for val in self._loss)
sync_loss = [self.detach_and_sync_loss(val) for val in loss]
self._last_step_loss = type(self._loss)(
val for idx, val in enumerate(sync_loss)
)
self._agg_loss = type(self._loss)(
self._agg_loss[idx] + val for idx, val in enumerate(sync_loss)
)
self._handle_ema_loss(loss=sync_loss)
if self.grad_accum > 1 and self.model_access.training:
loss = type(loss)(val / self.grad_accum for val in loss)
else:
loss = self._loss(*args, **kwargs)
sync_loss = self.detach_and_sync_loss(loss)
self._last_step_loss = sync_loss
self._agg_loss += sync_loss
self._handle_ema_loss(loss=sync_loss)
# Handle grad accumulation by dividing by the accumulation steps
if self.grad_accum > 1 and self.model_access.training:
loss = loss / self.grad_accum
return loss
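For the multiple-loss path (which the source marks as WIP), a hedged sketch assuming the `Stoke` instance was constructed with `loss=(loss_fn_a, loss_fn_b)` (hypothetical names):

```python
# with a tuple of callables at construction, the wrapped call returns a tuple;
# each element is already divided by grad_accum while the model is training
loss_a, loss_b = stoke_obj.loss(outputs, targets)

# backward accepts a single tensor or a list/tuple of tensors
stoke_obj.backward((loss_a, loss_b))
```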
model
def model(
self,
*args,
**kwargs
)
Wrapped model forward call
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | list or tuple | Additional arguments should be passed as keyword arguments | None |
**kwargs | dict | Extra arguments passed to the model forward call | None |
Returns:
Type | Description |
---|---|
model forward output | None |
??? example "View Source" def model(self, args, *kwargs):
"""Wrapped model forward call
Parameters
----------
*args: list or tuple
Additional arguments should be passed as keyword arguments
**kwargs: dict, optional
Extra arguments passed to the model forward call
Returns
-------
model forward output
"""
with self._runner.model_context:
return self._model(*args, **kwargs)
# return self.model_access(*args, **kwargs)
print
def print(
self,
msg: Union[str, List[str]],
single_line: bool = False
)
Wraps the runner's print_device call and forces printing on the _info_rank attribute(s)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
msg | str | message to print | None |
single_line | bool, default: False | if iterable print all on one line space and comma separated | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print(self, msg: Union[str, List[str]], single_line: bool = False):
"""Wraps the runners print device and forces print on the _info_rank attribute(s)
Parameters
----------
msg: str
message to print
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
self._runner.print_device(
msg=msg, rank=self._info_rank, single_line=single_line
)
print_ema_loss
def print_ema_loss(
self,
prepend_msg: str = 'Current EMA Loss',
single_line: bool = False
)
Prints the current ema loss synced across all devices
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prepend_msg | str, default: "Current EMA Loss" | message to prepend to the printed loss | None |
single_line | bool, default: False | if iterable print all on one line space and comma separated | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print_ema_loss(
self, prepend_msg: str = "Current EMA Loss", single_line: bool = False
):
"""Prints the current ema loss synced across all devices
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Current EMA Loss"
message to prepend to the printed loss
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
if isinstance(self._rolling_mean_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val:.3f}"
for idx, val in enumerate(self._rolling_mean_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._rolling_mean_loss:.3f}")
print_mean_accumulated_synced_loss
def print_mean_accumulated_synced_loss(
self,
prepend_msg: str = 'Mean Accumulated & Synced Loss',
pre_backwards: bool = True,
single_line: bool = False
)
Prints the mean accumulated and device synced loss only after the grad accumulation step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prepend_msg | str, default: "Mean Accumulated & Synced Loss" | message to prepend to the printed loss | None |
pre_backwards | bool, default: True | if being called pre backward step | None |
single_line | bool, default: False | if iterable print all on one line space and comma separated | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print_mean_accumulated_synced_loss(
self,
prepend_msg: str = "Mean Accumulated & Synced Loss",
pre_backwards: bool = True,
single_line: bool = False,
):
"""Prints the mean accumulated and device synced loss only after the grad accumulation step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
prepend_msg: str, default: "Mean Accumulated & Synced Loss"
message to prepend to the printed loss
pre_backwards: bool, default: True
if being called pre backward step
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
check_fn = self._check_pre_accum if pre_backwards else self._check_accum
if check_fn():
if isinstance(self._agg_loss, (list, tuple)):
print_vals = self._scale_agg_loss()
self.print(print_vals, single_line=single_line)
else:
self.print(f"{prepend_msg}: {self._scale_agg_loss():.3f}")
print_num_model_parameters
def print_num_model_parameters(
self,
normalize: stoke.utils.ParamNormalize = ParamNormalize.MILLION
)
Pretty prints the number of trainable model parameters, normalized by the given factor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
normalize | ParamNormalize, default: ParamNormalize.MILLION | ParamNormalize choice for pretty print normalizing | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print_num_model_parameters(
self, normalize: ParamNormalize = ParamNormalize.MILLION
):
"""
Parameters
----------
normalize: ParamNormalize, default: ParamNormalize.MILLION
ParamNormalize choice for pretty print normalizing
Returns
-------
None
"""
self.print(
f"Total Trainable Model Parameters: "
f"{(self.num_model_parameters / normalize.value):.3f} {normalize.name}"
)
print_on_devices
def print_on_devices(
self,
msg: Union[str, List[str]],
rank: Union[int, List[int], NoneType] = 0
)
Wraps runner print interface for shorter semantics
Parameters:
Name | Type | Description | Default |
---|---|---|---|
msg | str | message to print | None |
rank | Union[int, List[int]], default: 0 | which ranks to print on | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print_on_devices(
self, msg: Union[str, List[str]], rank: Optional[Union[int, List[int]]] = 0
):
"""Wraps runner print interface for shorter semantics
Parameters
----------
msg: str
message to print
rank: Union[int, List[int]], default: 0
which ranks to print on
Returns
-------
None
"""
self._runner.print_device(msg=msg, rank=rank)
print_synced_loss
def print_synced_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
prepend_msg: str = 'Step Synced Loss',
device=None,
single_line: bool = False
)
Prints a device synced loss at a single step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss | Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] | current loss(es) on the device | None |
prepend_msg | str, default: "Step Synced Loss" | message to prepend to the printed loss | None |
device | default: None | specify the device to place the synced loss on (defaults to same device) | same |
single_line | bool, default: False | if iterable print all on one line space and comma separated | None |
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def print_synced_loss(
self,
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
prepend_msg: str = "Step Synced Loss",
device=None,
single_line: bool = False,
):
"""Prints a device synced loss at a single step
Handles single or multiple losses. Prints only on devices specified by self._info_rank
Parameters
----------
loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
current loss(es) on the device
prepend_msg: str, default: "Step Synced Loss"
message to prepend to the printed loss
device: default: None
specify the device to place the synced loss on (defaults to same device)
single_line: bool, default: False
if iterable print all on one line space and comma separated
Returns
-------
None
"""
printable_loss = self.detach_and_sync_loss(loss, device)
if isinstance(printable_loss, (list, tuple)):
print_vals = [
f"{prepend_msg} {idx}: {val * self.grad_accum:.3f}"
for idx, val in enumerate(printable_loss)
]
self.print(print_vals, single_line=single_line)
else:
self.print(msg=f"{prepend_msg}: {printable_loss * self.grad_accum:.3f}")
reset
def reset(
self
)
Public method for resetting the underlying stoke state
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def reset(self):
"""Public method for resetting the underlying stoke state
Returns
-------
None
"""
self._reset()
reset_ema
def reset_ema(
self
)
Used to reset the current state of the rolling mean loss
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def reset_ema(self):
"""Used to reset the current state of the rolling mean loss
Returns
-------
None
"""
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
reset_tracking
def reset_tracking(
self
)
Public method for resetting all underlying stoke tracked variables
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def reset_tracking(self):
"""Public method for resetting all underlying stoke tracked variables
Returns
-------
None
"""
# Create some tracking vars
self._grad_accum_counter = 0
self._optimizer_steps = 0
self._backward_steps = 0
self._last_step_loss = self._set_loss_to_zero()
self._agg_loss = self._set_loss_to_zero()
self._rolling_mean_loss = self._set_loss_to_zero()
self._rolling_loss_steps = 0
save
def save(
self,
path: str,
name: str = uuid4(),
extension: str = 'pt',
create_directory: bool = True,
extras: Union[dict, NoneType] = None
)
Saves a model checkpoint using the correct backend interface
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | path to directory to save the model checkpoint (prefer absolute paths over relative paths) | None |
name | str, default: uuid4() | name used to save checkpoint file | None |
extension | str, default: 'pt' | extension used to save PyTorch model checkpoint | None |
create_directory | bool, default: True | flag to create the directory path if it doesn't exist | None |
extras | dict, default: None | a dictionary of any extra things to save | None |
Returns:
Type | Description |
---|---|
str | path to directory that the model checkpoint was saved |
str | full tag name the model checkpoint was saved as |
??? example "View Source" def save(
self,
path: str,
name: str = uuid4(),
extension: str = "pt",
create_directory: bool = True,
extras: Optional[dict] = None,
):
"""Saves a model checkpoint using the correct backend interface
Parameters
----------
path: str
path to directory to save the model checkpoint (prefer absolute paths over relative paths)
name: str, default: uuid4()
name used to save checkpoint file
extension: str, default: 'pt'
extension used to save PyTorch model checkpoint
create_directory: bool, default: True
flag to create the directory path if it doesn't exist
extras: dict, default: None
a dictionary of any extra things to save
Returns
-------
path: str
path to directory that the model checkpoint was saved
tag: str
full tag name the model checkpoint was saved as
"""
out_path, tag = self._runner.save(
model=self._model if self.fully_sharded else self.model_access,
optimizer=self.optimizer,
path=path,
backward_step=self._backward_steps,
grad_accum_step=self._grad_accum_counter,
optimizer_step=self._optimizer_steps,
name=name,
scaler_dict=self.fp16_state_dict,
extension=extension,
create_directory=create_directory,
extras=extras,
status=self.status.status,
)
self.print(f"Successfully saved model checkpoint to {out_path}/{tag}")
return out_path, tag
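A hedged save/load round trip (the path, name, and `extras` payload are illustrative; `stoke_obj` is an assumed, already-configured `Stoke` instance):

```python
# save a checkpoint; returns the output directory and the full tag written
out_path, tag = stoke_obj.save(
    path="/tmp/checkpoints",
    name="epoch-3",        # illustrative name instead of the uuid4() default
    extras={"epoch": 3},   # arbitrary user payload round-tripped by load()
)

# later (or in a fresh process with an identically configured Stoke instance),
# restore weights, optimizer state, scaler state, and step counters
extras = stoke_obj.load(path=out_path, tag=tag)
assert extras["epoch"] == 3
```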
step
def step(
self
)
Wrapped step call
Handles grad clipping internally
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def step(self):
"""Wrapped step call
Handles grad clipping internally
Returns
-------
None
"""
# Step the optimizer only if the modulo is zero
if self._check_accum():
if self._verbose and self.grad_accum > 0:
self.print(f"Gradient Accumulation Steps: {self.grad_accum}")
# Clip if needed
if self.grad_clip is not None:
self._runner.clip_grad(
self.grad_clip,
self._model if self.fully_sharded else self.model_access,
self._optimizer,
oss=self.oss,
horovod=self.is_horovod,
deepspeed=self.is_deepspeed,
fsdp=self.fully_sharded,
)
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
# Reset for the accumulated step
self._reset()
# Increment the number of step calls to the optimizer
self._optimizer_steps += 1
# if deepspeed we need to step every time as it handles the grad accumulation internally
elif self.is_deepspeed:
# Handle the optimizer step
step_cm = (
self._runner.step_context(self._optimizer)
if self.grad_clip is not None
else nullcontext()
)
with step_cm:
self._runner.step_call(
model=self.model_access, optimizer=self._optimizer
)
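Putting the wrapped calls together, a minimal training-loop sketch under this class's conventions (`stoke_obj`, `loader`, `targets`, and `num_epochs` are assumed names, not part of the API):

```python
for epoch in range(num_epochs):
    for inputs, targets in loader:
        outputs = stoke_obj.model(inputs)        # wrapped forward call
        loss = stoke_obj.loss(outputs, targets)  # scaled by 1/grad_accum in training
        stoke_obj.backward(loss)                 # increments the grad-accum counter
        stoke_obj.step()                         # clips, steps, and resets only on
                                                 # the accumulation boundary
    stoke_obj.print_ema_loss()                   # log the rolling mean loss
```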
zero_grads
def zero_grads(
self
)
Zeros the optimizer grads depending on the optimizer type
Returns:
Type | Description |
---|---|
None | None |
??? example "View Source" def zero_grads(self):
"""Zeros the optimizer grads depending on the optimizer type
Returns
-------
None
"""
zero_optimizer_grads(
optimizer=self._optimizer, apex=self.is_apex, horovod=self.is_horovod
)