Module stoke

Stoke is a lightweight wrapper for PyTorch that provides a simple unified interface for context switching

Please refer to the documentation provided in the README.md

??? example "View Source" # -- coding: utf-8 --

    # Copyright FMR LLC <opensource@fidelity.com>

    # SPDX-License-Identifier: Apache-2.0



    """Stoke is a lightweight wrapper for PyTorch that provides a simple unified interface for context switching



    Please refer to the documentation provided in the README.md

    """



    from .configs import *

    from .data import BucketedDistributedSampler

    from .status import DistributedOptions, FP16Options

    from .stoke import Stoke

    from .utils import ParamNormalize



    __all__ = [

        "Stoke",

        "ParamNormalize",

        "FP16Options",

        "DistributedOptions",

        "StokeOptimizer",

        "ClipGradNormConfig",

        "ClipGradConfig",

        "FairscaleOSSConfig",

        "FairscaleSDDPConfig",

        "FairscaleFSDPConfig",

        "HorovodConfig",

        "ApexConfig",

        "DeepspeedConfig",

        "DDPConfig",

        "AMPConfig",

        "DeepspeedAIOConfig",

        "DeepspeedActivationCheckpointingConfig",

        "DeepspeedFlopsConfig",

        "DeepspeedFP16Config",

        "DeepspeedPLDConfig",

        "DeepspeedOffloadOptimizerConfig",

        "DeepspeedOffloadParamConfig",

        "DeepspeedTensorboardConfig",

        "DeepspeedZeROConfig",

        "BucketedDistributedSampler",

    ]



    from ._version import get_versions



    __version__ = get_versions()["version"]



    del get_versions
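
The names in `__all__` above are the public surface of the package, so everything documented on this page can be imported directly from `stoke`. A minimal sketch (the specific field values are illustrative only):

    from stoke import AMPConfig, ClipGradNormConfig, Stoke

    # Configuration objects are plain containers and can be built independently of
    # any model or optimizer; see the class docs below for every field
    amp_config = AMPConfig(init_scale=2.0 ** 14)
    clip_config = ClipGradNormConfig(max_norm=1.0, norm_type=2.0)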

Sub-modules

Classes

AMPConfig

class AMPConfig(
    backoff_factor: float = 0.5,
    growth_factor: float = 2.0,
    growth_interval: int = 2000,
    init_scale: float = 65536.0
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| backoff_factor | float, default: 0.5 | Factor by which the scale is multiplied during update if inf/NaN gradients occur in an iteration | None |
| growth_factor | float, default: 2.0 | Factor by which the scale is multiplied during update if no inf/NaN gradients occur for growth_interval consecutive iterations | None |
| growth_interval | int, default: 2000 | Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor | None |
| init_scale | float, default: 2.**16 | Initial scale factor | None |

??? example "View Source" class AMPConfig:

        """PyTorch AMP configuration class



        Attributes

        ----------

        backoff_factor : float, default: 0.5

            Factor by which the scale is multiplied during update if inf/NaN gradients occur in an iteration

        growth_factor : float, default: 2.0

            Factor by which the scale is multiplied during update if no inf/NaN gradients occur for growth_interval consecutive iterations.

        growth_interval : int, default: 2000

            Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by

            growth_factor

        init_scale : float, default: 2.**16

            Initial scale factor



        """



        backoff_factor: float = 0.5

        growth_factor: float = 2.0

        growth_interval: int = 2000

        init_scale: float = 2.0 ** 16
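
As a quick illustration (values are arbitrary, not recommendations), any subset of the fields documented above can be overridden via keyword arguments:

    from stoke import AMPConfig

    # Loosen the dynamic loss scaling: start from a smaller scale and grow it
    # less aggressively than the defaults shown above
    amp_config = AMPConfig(
        init_scale=2.0 ** 13,   # default is 2.**16
        growth_factor=1.5,      # default is 2.0
        growth_interval=1000,   # default is 2000
    )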

ApexConfig

class ApexConfig(
    cast_model_outputs: Union[torch.dtype, NoneType] = None,
    convert_to_sync_batch_norm: bool = False,
    max_loss_scale: float = 16777216.0,
    min_loss_scale: Union[float, NoneType] = None,
    scaler_per_loss: bool = False,
    verbosity: int = 0
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cast_model_outputs | Optional[torch.dtype], default: None | Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level | None |
| convert_to_sync_batch_norm | bool, default: False | Automatically convert all batch norm calls to apex.parallel.SyncBatchNorm calls (https://nvidia.github.io/apex/parallel.html#apex.parallel.SyncBatchNorm) | None |
| max_loss_scale | float, default: 2.**24 | Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling | None |
| min_loss_scale | Optional[float], default: None | Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed | None |
| scaler_per_loss | bool, default: False | Option to impose a scaler for each loss instead of a global scaler | None |
| verbosity | int, default: 0 | Set to 0 to suppress Amp-related output | None |

??? example "View Source" class ApexConfig:

        """Nvidia APEX configuration class



        Attributes

        ----------

        cast_model_outputs: Optional[torch.dtype], default: None

            Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level

        convert_to_sync_batch_norm: bool, default: False

            Automatically convert all batch norm calls to apex.parallel.SyncBatchNorm calls

            https://nvidia.github.io/apex/parallel.html#apex.parallel.SyncBatchNorm

        max_loss_scale: float, default: 2.**24

            Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling

        min_loss_scale: Optional[float], default: None

            Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None

            means that no floor is imposed

        scaler_per_loss: bool, default: False

            Option to impose a scaler for each loss instead of a global scaler

        verbosity: int, default: 0

            Set to 0 to suppress Amp-related output



        """



        cast_model_outputs: Optional[torch.dtype] = None

        convert_to_sync_batch_norm: bool = False

        max_loss_scale: float = 2.0 ** 24

        min_loss_scale: Optional[float] = None

        scaler_per_loss: bool = False

        verbosity: int = 0
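
A short sketch of constructing the config with a few of the fields above overridden (values are illustrative):

    import torch
    from stoke import ApexConfig

    # Cast model outputs back to full precision and swap in SyncBatchNorm; the
    # remaining fields keep the defaults listed above
    apex_config = ApexConfig(
        cast_model_outputs=torch.float32,
        convert_to_sync_batch_norm=True,
        verbosity=0,
    )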

BucketedDistributedSampler

class BucketedDistributedSampler(
    dataset: torch.utils.data.dataset.Dataset,
    buckets: int,
    batch_size: int,
    sorted_idx: List,
    backend: stoke.status.DistributedOptions,
    allow_bucket_overlap: bool = False,
    num_replicas: Union[int, NoneType] = None,
    rank: Union[int, NoneType] = None,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
    info_rank: int = 0
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_replicas | int, default: None | number of replicas | None |
| rank | int, default: None | current device rank | None |
| epoch | int | current training epoch | None |
| drop_last | bool, default: False | whether to drop last set of samples that don't fit into a batch | None |
| shuffle | bool, default: True | flag to shuffle dataset | None |
| seed | int, default: 0 | seed to use for generators | None |
| buckets | int | number of buckets to break the dataset into | None |
| sorted_n_samples | list | sorted list of samples by the characteristic to bucket by (e.g. seq len) | None |
| batch_size | int | batch size that will be used (needed to make sure slices are correct) | None |
| allow_bucket_overlap | bool, default: False | allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into an un-bucketed batch | None |
| slice_size | int | computed from batch size and number of replicas | None |
| num_samples_per_bucket | int | computed value that represents the number of samples in a single bucket | None |
| num_slices_per_bucket | int | computed value that represents the number of slices available in a bucket | None |
| bucket_idx | list | computed value that makes a contiguous list of indices in each bucket | None |
| rounded_num_samples_per_bucket | int | computed value post round for number of samples in a single bucket | None |
| rounded_num_samples_per_replica | int | computed value post round for number of slices available in a bucket | None |

??? example "View Source" class BucketedDistributedSampler(Sampler[T_co]):

        """Sampler that buckets samples by sorted_idx and then randomly samples from a specific bucket to prevent excess

        padding leading to wasted computation



        Borrowing heavily from the base DistributedSampler

        https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler



        Attributes

        ----------

        num_replicas: int, default: None

            number of replicas

        rank: int, default: None

            current device rank

        epoch: int

            current training epoch

        drop_last: bool, default: False

            whether to drop last set of samples that don't fit into a batch

        shuffle: bool, default: True

            flag to shuffle dataset

        seed: int, default: 0

            seed to use for generators

        buckets: int

            number of buckets to break the dataset into

        sorted_n_samples: list

            sorted list of samples by the characteristic to bucket by (e.g. seq len)

        batch_size: int

            batch size that will be used (needed to make sure slices are correct)

        allow_bucket_overlap: bool, default: False

            allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

            an un-bucketed batch

        slice_size: int

            computed from batch size and number of replicas

        num_samples_per_bucket: int

            computed value that represents the number of samples in a single bucket

        num_slices_per_bucket: int

            computed value that represents the number of slices available in a bucket

        bucket_idx: list

            computed value that make a contiguous list of indices in each bucket

        rounded_num_samples_per_bucket: int

            computed value post round for number of samples in a single bucket

        rounded_num_samples_per_replica: int

            computed value post round for number of slices available in a bucket



        """



        def __init__(

            self,

            dataset: Dataset,

            buckets: int,

            batch_size: int,

            sorted_idx: List,

            backend: DistributedOptions,

            allow_bucket_overlap: bool = False,

            num_replicas: Optional[int] = None,

            rank: Optional[int] = None,

            shuffle: bool = True,

            seed: int = 0,

            drop_last: bool = False,

            info_rank: int = 0,

        ) -> None:

            """Init for BucketedDistributedSampler



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            buckets: int

                number of buckets to break the dataset into

            batch_size: int

                batch size that will be used (needed to make sure slices are correct)

            sorted_idx: list

                sorted list of samples by the characteristic to bucket by (e.g. seq len)

            backend: DistributedOptions

                which backend is being used (as rank, world size, etc. need to be used)

            allow_bucket_overlap: bool, default: False

                allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

                an un-bucketed batch

            num_replicas: int, default: None

                number of replicas

            rank: int, default: None

                current device rank

            shuffle: bool, default: True

                flag to shuffle dataset

            seed: int, default: 0

                seed to use for generators

            drop_last: bool, default: False

                whether to drop last set of samples that don't fit into a batch

            info_rank: int, default: 0

                which device to print information on



            """

            # If the backend isn't DDP there needs to be an additional import

            num_replicas, rank = self._conditional_distributed(

                backend=backend, num_replicas=num_replicas, rank=rank

            )

            self.num_replicas = num_replicas

            self.rank = rank

            self.epoch = 0

            self.drop_last = drop_last

            self.shuffle = shuffle

            self.seed = seed

            self.buckets = buckets

            self.sorted_n_samples = sorted_idx

            # Batch size is needed here so a contiguous iter of buckets can be formed

            self.batch_size = batch_size

            # This is a flag to batch up the dropped samples (that would be 'wasted') if drop_last is flagged

            self.allow_bucket_overlap = allow_bucket_overlap

            # Calculate the size of each slice that will be indexed across the replicas

            self.slice_size = self.batch_size * self.num_replicas

            # Calculate the size of the buckets (rounded or not based on drop last)

            self.num_samples_per_bucket = self._get_size(

                len(dataset), self.buckets, self.drop_last

            )

            # Calculate the number of slices per bucket

            self.num_slices_per_bucket = self._get_size(

                self.num_samples_per_bucket, self.slice_size, self.drop_last

            )

            if self.num_samples_per_bucket < self.slice_size:

                raise ValueError(

                    f"Stoke -- Resulting number of slices (batch * replicas) per bucket "

                    f"({self.num_samples_per_bucket}) is less than the batch size "

                    f"({self.batch_size})"

                )

            if self.num_slices_per_bucket < 2:

                raise ValueError(

                    f"Stoke -- Number of slices per bucket {self.num_slices_per_bucket} is less than 2 "

                    f"which is not recommended"

                )

            if self.num_samples_per_bucket < 100:

                raise ValueError(

                    f"Stoke -- Number of samples per bucket {self.num_samples_per_bucket} is less than 100 "

                    f"which is not recommended as this might lead to dropping of excessive data"

                )

            # Split into buckets and turn into lists

            self.bucket_idx = [

                list(val) for val in np.array_split(self.sorted_n_samples, self.buckets)

            ]

            # Calculate the post rounded numbers

            self.rounded_num_samples_per_bucket = (

                self.slice_size * self.num_slices_per_bucket

            )

            self.rounded_num_samples_per_replica = (

                self.num_slices_per_bucket * self.batch_size * self.buckets

            )

            # Add the bucket overlap samples

            if self.allow_bucket_overlap:

                self.rounded_num_samples_per_replica += (

                    (len(dataset) - (self.rounded_num_samples_per_bucket * self.buckets))

                    // self.slice_size

                ) * self.batch_size

            if self.rank == info_rank:

                print(

                    f"Stoke -- BucketedDistributedSampler -- # Samples Per Bucket: "

                    f"{self.rounded_num_samples_per_bucket}, # of Samples Per Replica: "

                    f"{self.rounded_num_samples_per_replica}"

                )



        def _conditional_distributed(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank

            """

            return self._check_backend(backend, num_replicas, rank)



        def _get_backend_functions(self, backend: DistributedOptions):

            """Gets backend functions if needed



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used



            Returns

            -------

            Tuple[bool, int, int]

                is_init, num_replicas, rank



            """

            if backend.value == "ddp" or backend.value == "deepspeed":

                return (

                    torch.distributed.is_initialized,

                    torch.distributed.get_world_size,

                    torch.distributed.get_rank,

                )

            else:

                return hvd.is_initialized, hvd.size, hvd.rank



        def _check_backend(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """Checks the backend for correct device info



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank



            """

            if num_replicas is None or rank is None:

                is_avail, get_world_size, get_rank = self._get_backend_functions(

                    backend=backend

                )

            if num_replicas is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                num_replicas = get_world_size()

            if rank is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                rank = get_rank()

            return num_replicas, rank



        @staticmethod

        def _get_size(data_len: int, split_var: int, drop_last: bool = False):

            """Gets the size of a split



            Parameters

            ----------

            data_len: int

                current dataset length

            split_var: int

                how many to split into

            drop_last: bool, default: False

                drop last hanging samples if not batch_size



            Returns

            -------

            num_samples: int



            """

            if drop_last:

                num_samples = data_len // split_var

            else:

                num_samples = ceil(data_len / split_var)

            return num_samples



        def __iter__(self) -> Iterator[T_co]:

            """Handles assembling the batches from a bucketed perspective



            Shuffle bucket order->Pad if necessary->Slice across replicas->Possibly batch up residuals->shuffle bucketed

            batches->Unroll into list->Make iter



            Returns

            -------

            Iterator[T_co]



            """

            # Shuffle the bucketed idx

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute each bucket

                indices = [

                    [val[idx] for idx in torch.randperm(len(val), generator=g).tolist()]

                    for val in self.bucket_idx

                ]

            else:

                indices = self.bucket_idx

            # Iterate over the buckets

            for idx, val in enumerate(indices):

                # If this is true we need to handle padding

                if (self.num_slices_per_bucket * self.slice_size) > len(val):

                    split_val = self._handle_padding(val)

                    indices[idx] = list(itertools.chain(*split_val))

                    assert len(indices[idx]) == self.rounded_num_samples_per_bucket

            # Now slice across replicas

            final_indices = []

            for val in indices:

                for idx in range(self.num_slices_per_bucket):

                    replica_slice = val[

                        (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                    ][self.rank : self.slice_size : self.num_replicas]

                    final_indices.append(replica_slice)

            # If bucket overlap is allowed then we just batch up the residual indices

            if self.drop_last and self.allow_bucket_overlap:

                residual_idx = list(

                    itertools.chain(

                        *[val[self.rounded_num_samples_per_bucket :] for val in indices]

                    )

                )

                if len(residual_idx) > self.slice_size:

                    # Cut by slices then by replicas

                    residual_idx = [

                        residual_idx[

                            (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                        ][self.rank : self.slice_size : self.num_replicas]

                        for idx in range(len(residual_idx) // self.slice_size)

                    ]

                    # Append to the final indices

                    final_indices.extend(residual_idx)

            # Shuffle the bucketed batches

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute the bucket order

                final_indices = [

                    final_indices[val]

                    for val in torch.randperm(len(final_indices), generator=g)

                ]

            # Unroll into a single list

            final_indices = list(itertools.chain(*final_indices))

            assert len(final_indices) == self.rounded_num_samples_per_replica

            return iter(final_indices)



        def _handle_padding(self, idx_list: List):

            """Handles padding out if a batch is short



            Parameters

            ----------

            idx_list: List

                list of indices



            Returns

            -------

            split_val: List

                list with correctly padded sizes



            """

            split_val = []

            for idx in range(self.num_slices_per_bucket):

                if idx == (self.num_slices_per_bucket - 1):

                    # Get the short batch

                    short_batch = idx_list[(idx * self.slice_size) :]

                    # Short batch replica slice sizes

                    short_len = [

                        self.batch_size - len(list(val))

                        for val in np.array_split(short_batch, self.num_replicas)

                    ]

                    # Pop the necessary values from the entire bucket

                    pad_values = [

                        idx_list[s_idx : (self.num_replicas * s_len) : self.num_replicas]

                        for s_idx, s_len in enumerate(short_len)

                    ]

                    # If not a consistent list then we need to reorder so that the step size alignment slicing

                    # of the replicas works

                    if len(set(short_len)) != 1:

                        # here we need to find the first larger idx and reorder

                        first_idx = short_len.index(max(set(short_len)))

                        # Reorder

                        pad_values = pad_values[first_idx:] + pad_values[0:first_idx]

                    extended_batch = short_batch + [

                        pad

                        for pad in list(

                            itertools.chain(*itertools.zip_longest(*pad_values))

                        )

                        if pad is not None

                    ]

                    split_val.append(extended_batch)

                else:

                    split_val.append(

                        idx_list[(idx * self.slice_size) : ((idx + 1) * self.slice_size)]

                    )

            return split_val



        def __len__(self) -> int:

            return self.rounded_num_samples_per_replica



        def set_epoch(self, epoch: int) -> None:

            """Sets the epoch for this sampler.



            When :attr:`shuffle=True`, this ensures all replicas

            use a different random ordering for each epoch. Otherwise, the next iteration of this

            sampler will yield the same ordering.



            Parameters

            ----------

            epoch: int

                Epoch number



            """

            self.epoch = epoch

Ancestors (in MRO)

  • torch.utils.data.sampler.Sampler
  • typing.Generic

Methods

set_epoch

def set_epoch(
    self,
    epoch: int
) -> None

Sets the epoch for this sampler.

When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | int | Epoch number | None |

??? example "View Source" def set_epoch(self, epoch: int) -> None:

            """Sets the epoch for this sampler.



            When :attr:`shuffle=True`, this ensures all replicas

            use a different random ordering for each epoch. Otherwise, the next iteration of this

            sampler will yield the same ordering.



            Parameters

            ----------

            epoch: int

                Epoch number



            """

            self.epoch = epoch
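
A rough usage sketch under stated assumptions: the `ToyVariableLengthDataset` is a hypothetical placeholder, the `DistributedOptions.ddp` member name is assumed from the backend values referenced in this module, and the distributed process group is assumed to already be initialized (e.g. inside a Stoke DDP run):

    import torch
    from torch.utils.data import DataLoader, Dataset
    from stoke import BucketedDistributedSampler, DistributedOptions

    class ToyVariableLengthDataset(Dataset):
        """Placeholder dataset of variable-length 1-D tensors (illustration only)."""

        def __init__(self, n: int = 1000):
            self.items = [torch.randn(torch.randint(5, 50, (1,)).item()) for _ in range(n)]

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    train_dataset = ToyVariableLengthDataset()
    # Dataset indices pre-sorted by the bucketing characteristic (sequence length here)
    sorted_idx = sorted(range(len(train_dataset)), key=lambda i: len(train_dataset[i]))

    sampler = BucketedDistributedSampler(
        dataset=train_dataset,
        buckets=4,
        batch_size=32,
        sorted_idx=sorted_idx,
        backend=DistributedOptions.ddp,  # assumed enum member; match the Stoke backend in use
        drop_last=True,
    )

    # The sampler yields per-replica sample indices, so hand the DataLoader the same
    # batch size and let the sampler own the shuffling (a padding collate_fn would be
    # supplied in real use)
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, shuffle=False)
    sampler.set_epoch(0)  # call with the epoch number before each epoch when shuffle=True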

ClipGradConfig

class ClipGradConfig(
    clip_value: float
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clip_value | float | maximum allowed absolute value of the gradients [-clip_value, clip_value] | None |

??? example "View Source" class ClipGradConfig:

        """Gradient clipping by value configuration class



        Attributes

        ----------

        clip_value: float

            maximum allowed absolute value of the gradients [-clip_value, clip_value]



        """



        clip_value: float

ClipGradNormConfig

class ClipGradNormConfig(
    max_norm: float,
    norm_type: float
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| max_norm | float | max norm of the gradients | None |
| norm_type | float | type of the used p-norm | None |

??? example "View Source" class ClipGradNormConfig:

        """Gradient clipping by p-norm configuration class



        Attributes

        ----------

        max_norm: float

            max norm of the gradients

        norm_type: float

            type of the used p-norm



        """



        max_norm: float

        norm_type: float
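
For illustration, either clipping behavior is selected by constructing the matching class (the values here are arbitrary):

    from stoke import ClipGradConfig, ClipGradNormConfig

    # Clip each gradient element into [-1.0, 1.0]
    clip_by_value = ClipGradConfig(clip_value=1.0)

    # Or clip the total gradient L2 norm to at most 5.0
    clip_by_norm = ClipGradNormConfig(max_norm=5.0, norm_type=2.0)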

DDPConfig

class DDPConfig(
    local_rank: Union[int, NoneType],
    auto_mpi_discovery: bool = False,
    convert_to_sync_batch_norm: bool = False,
    backend: stoke.configs.BackendOptions = 'nccl',
    broadcast_buffers: bool = True,
    bucket_cap_mb: int = 25,
    find_unused_parameters: bool = False,
    gradient_as_bucket_view: bool = False,
    init_method: str = 'env://',
    no_sync: bool = True
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| local_rank | Optional[int] | Current local rank of the device (provided here, as LOCAL_RANK env var, or parsed from --local_arg) | None |
| auto_mpi_discovery | bool, default: False | if distributed environment variables are not set, attempt to discover them from MPI (using underlying deepspeed function call) | None |
| convert_to_sync_batch_norm | bool, default: False | Automatically convert all batch norm calls to torch.nn.SyncBatchNorm calls (https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html) | None |
| backend | BackendOptions, default: 'nccl' | Which communication backend to use | None |
| broadcast_buffers | bool, default: True | Flag that enables syncing (broadcasting) buffers of the module at beginning of the forward function | None |
| bucket_cap_mb | int, default: 25 | DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in MegaBytes (MB) | None |
| find_unused_parameters | bool, default: False | Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. Note that all forward outputs that are derived from module parameters must participate in calculating loss and later the gradient computation. If they don’t, this wrapper will hang waiting for autograd to produce gradients for those parameters. Any outputs derived from module parameters that are otherwise unused can be detached from the autograd graph using torch.Tensor.detach | None |
| gradient_as_bucket_view | bool, default: False | When set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. Moreover, it avoids the overhead of copying between gradients and allreduce communication buckets. When gradients are views, detach_() cannot be called on the gradients. If hitting such errors, please fix it by referring to the zero_grad() function in torch/optim/optimizer.py as a solution. | None |
| init_method | str, default: 'env://' | URL specifying how to initialize the process group | None |
| no_sync | bool, default: True | for any DDP-based method (including the SDDP and FSDP wrappers): if activated, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass after exiting the context. no_sync might lead to higher memory usage but lower communication overhead | None |

??? example "View Source" class DDPConfig:

        """PyTorch DistributedDataParallel configuration class



        Attributes

        ----------

        local_rank: Optional[int]

            Current local rank of the device (provided here, as LOCAL_RANK env var, or parsed from --local_arg)

        auto_mpi_discovery: bool, default: False

            if distributed environment variables are not set, attempt to discover them from MPI (using underlying deepspeed

            function call)

        convert_to_sync_batch_norm: bool, default: False

            Automatically convert all batch norm calls to torch.nn.SyncBatchNorm calls

            https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html

        backend: BackendOptions, default: 'nccl'

            Which communication backend to use

        broadcast_buffers: bool, default: True

            Flag that enables syncing (broadcasting) buffers of the module at beginning of the forward function

        bucket_cap_mb: int, default: 25

            DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket

            can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in MegaBytes (MB)

        find_unused_parameters: bool, default: False

            Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward

            function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready

            to be reduced. Note that all forward outputs that are derived from module parameters must participate in

            calculating loss and later the gradient computation. If they don’t, this wrapper will hang waiting for autograd

            to produce gradients for those parameters. Any outputs derived from module parameters that are otherwise unused

            can be detached from the autograd graph using torch.Tensor.detach

        gradient_as_bucket_view: bool, default: False

            When set to True, gradients will be views pointing to different offsets of allreduce communication

            buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients

            size. Moreover, it avoids the overhead of copying between gradients and allreduce communication buckets. When

            gradients are views, detach_() cannot be called on the gradients. If hitting such errors, please fix it by

            referring to the zero_grad() function in torch/optim/optimizer.py as a solution.

        init_method: str, default: 'env://'

            URL specifying how to initialize the process group

        no_sync: bool, default: True

            for any DDP based method (including SDDP and FSDP wrappers) -- if activated, gradients will be accumulated on

            module variables, which will later be synchronized in the first forward-backward pass after exiting the

            context. no sync might lead to higher memory usage but lower communication overhead



        """



        local_rank: Optional[int]

        auto_mpi_discovery: bool = False

        convert_to_sync_batch_norm: bool = False

        backend: BackendOptions = "nccl"

        broadcast_buffers: bool = True

        bucket_cap_mb: int = 25

        find_unused_parameters: bool = False

        gradient_as_bucket_view: bool = False

        init_method: str = "env://"

        no_sync: bool = True
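
A hedged sketch of a typical single-node setup; reading the local rank from the LOCAL_RANK environment variable is an assumption about how the launcher passes it, not part of the class itself:

    import os
    from stoke import DDPConfig

    # LOCAL_RANK is what torchrun/torch.distributed.launch export; the docs above
    # note local_rank can also come from the env var or the --local_arg flag
    ddp_config = DDPConfig(
        local_rank=int(os.environ.get("LOCAL_RANK", 0)),
        backend="nccl",                    # default; accepts the BackendOptions values
        convert_to_sync_batch_norm=True,   # swap BatchNorm layers for torch.nn.SyncBatchNorm
        find_unused_parameters=False,
    )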

DeepspeedAIOConfig

class DeepspeedAIOConfig(
    block_size: int = 1048576,
    ignore_unused_parameters: bool = True,
    overlap_events: bool = True,
    queue_depth: int = 8,
    single_submit: bool = False,
    thread_count: int = 1
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| block_size | int, default: 1048576 | I/O block size in bytes | None |
| ignore_unused_parameters | bool, default: True | Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. | None |
| overlap_events | bool, default: True | Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | None |
| queue_depth | int, default: 8 | I/O queue depth | None |
| single_submit | bool, default: False | Submit requests to storage device as multiple individual requests as opposed to one block of requests. | None |
| thread_count | int, default: 1 | Intra-request parallelism for each read/write submitted by a user thread. | None |

??? example "View Source" class DeepspeedAIOConfig:

        """Deepspeed asynchronous I/O configuration class



        Attributes

        ----------

        block_size: int, default: 1048576

            I/O block size in bytes

        ignore_unused_parameters: bool, default: True

            Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks.

            This controls whether or not training should terminate with an error message when unused parameters are

            detected.

        overlap_events: bool, default: True

            Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests.

        queue_depth: int, default: 8

            I/O queue depth

        single_submit: bool, default: False

            Submit requests to storage device as multiple individual requests as opposed to one block of requests.

        thread_count: int, default: 1

            Intra-request parallelism for each read/write submitted by a user thread.



        """



        block_size: int = 1048576

        ignore_unused_parameters: bool = True

        overlap_events: bool = True

        queue_depth: int = 8

        single_submit: bool = False

        thread_count: int = 1

DeepspeedActivationCheckpointingConfig

class DeepspeedActivationCheckpointingConfig(
    contiguous_memory_optimization: bool = False,
    cpu_checkpointing: bool = False,
    number_checkpoints: Union[int, NoneType] = None,
    partition_activations: bool = False,
    profile: bool = False,
    synchronize_checkpoint_boundary: bool = False
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| contiguous_memory_optimization | bool, default: False | Copies partitioned activations so that they are contiguous in memory | None |
| cpu_checkpointing | bool, default: False | Offloads partitioned activations to CPU if partition_activations is enabled | None |
| number_checkpoints | Optional[int], default: None | Total number of activation checkpoints used to allocate memory buffer for contiguous_memory_optimization | None |
| partition_activations | bool, default: False | Enables partition activation when used with model parallelism | None |
| profile | bool, default: False | Logs the forward and backward time for each checkpoint function | None |
| synchronize_checkpoint_boundary | bool, default: False | Inserts torch.cuda.synchronize() at each checkpoint boundary | None |

??? example "View Source" class DeepspeedActivationCheckpointingConfig:

        """Deepspeed activation checkpointing configuration class



        Attributes

        ----------

        contiguous_memory_optimization: bool, default: False

            Copies partitioned activations so that they are contiguous in memory

        cpu_checkpointing: bool, default: False

            Offloads partitioned activations to CPU if partition_activations is enabled

        number_checkpoints: Optional[int], default: None

            Total number of activation checkpoints used to allocate memory buffer for contiguous_memory_optimization

        partition_activations: bool, default: False

            Enables partition activation when used with model parallelism

        profile: bool, default: False

            Logs the forward and backward time for each checkpoint function

        synchronize_checkpoint_boundary: bool, default: False

            Inserts torch.cuda.synchronize() at each checkpoint boundary



        """



        contiguous_memory_optimization: bool = False

        cpu_checkpointing: bool = False

        number_checkpoints: Optional[int] = None

        partition_activations: bool = False

        profile: bool = False

        synchronize_checkpoint_boundary: bool = False

DeepspeedConfig

class DeepspeedConfig(
    activation_checkpointing: Union[stoke.configs.DeepspeedActivationCheckpointingConfig, NoneType] = DeepspeedActivationCheckpointingConfig(contiguous_memory_optimization=False, cpu_checkpointing=False, number_checkpoints=None, partition_activations=False, profile=False, synchronize_checkpoint_boundary=False),
    aio: Union[stoke.configs.DeepspeedAIOConfig, NoneType] = DeepspeedAIOConfig(block_size=1048576, ignore_unused_parameters=True, overlap_events=True, queue_depth=8, single_submit=False, thread_count=1),
    auto_mpi_discovery: bool = True,
    disable_allgather: bool = False,
    dist_backend: stoke.configs.BackendOptions = 'nccl',
    distributed_port: int = 29500,
    dump_state: bool = False,
    flops_profiler: Union[stoke.configs.DeepspeedFlopsConfig, NoneType] = None,
    fp16: Union[stoke.configs.DeepspeedFP16Config, NoneType] = None,
    fp32_allreduce: bool = False,
    gradient_predivide_factor: float = 1.0,
    init_method: str = 'env://',
    prescale_gradients: bool = False,
    progressive_layer_drop: Union[stoke.configs.DeepspeedPLDConfig, NoneType] = None,
    sparse_gradients: bool = False,
    steps_per_print: int = 10,
    tensorboard: Union[stoke.configs.DeepspeedTensorboardConfig, NoneType] = None,
    verbose: bool = True,
    wall_clock_breakdown: bool = False,
    zero_optimization: Union[stoke.configs.DeepspeedZeROConfig, NoneType] = DeepspeedZeROConfig(allgather_bucket_size=500000000, allgather_partitions=True, contiguous_gradients=False, ignore_unused_parameters=True, legacy_stage1=False, offload_optimizer=None, offload_param=None, overlap_comm=False, reduce_bucket_size=500000000, reduce_scatter=True, stage=0, stage3_max_live_parameters=1000000000, stage3_max_reuse_distance=1000000000, stage3_prefetch_bucket_size=500000000, stage3_param_persistence_threshold=1000000, stage3_gather_fp16_weights_on_model_save=False, sub_group_size=1000000000000)
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activation_checkpointing | Optional[DeepspeedActivationCheckpointingConfig], default: DeepspeedActivationCheckpointingConfig() | Enables and configures activation checkpointing | None |
| aio | Optional[DeepspeedAIOConfig], default: DeepspeedAIOConfig() | Configuring the asynchronous I/O module for offloading parameter and optimizer states to persistent (NVMe) storage | None |
| auto_mpi_discovery | bool, default: True | if distributed environment variables are not set, attempt to discover them from MPI | None |
| disable_allgather | bool, default: False | Disables allgather | None |
| dist_backend | BackendOptions, default: 'nccl' | Which communication backend to use | None |
| distributed_port | int, default: 29500 | torch distributed backend port | None |
| dump_state | bool, default: False | Print out state information of DeepSpeed object after initialization | None |
| flops_profiler | Optional[DeepspeedFlopsConfig], default: None | Enables and configures the flops profiler. This would also enable wall_clock_breakdown | None |
| fp16 | Optional[DeepspeedFP16Config], default: None | Enables and configures mixed precision/FP16 training that leverages NVIDIA’s Apex package | None |
| fp32_allreduce | bool, default: False | During gradient averaging perform allreduce with 32 bit values | None |
| gradient_predivide_factor | float, default: 1.0 | Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability when scaling to large numbers of GPUs | None |
| init_method | str, default: 'env://' | URL specifying how to initialize the process group | None |
| prescale_gradients | bool, default: False | Scale gradients before doing allreduce | None |
| progressive_layer_drop | Optional[DeepspeedPLDConfig], default: None | Enables and configures progressive layer dropping | None |
| sparse_gradients | bool, default: False | Enable sparse compression of torch.nn.Embedding gradients | None |
| steps_per_print | int, default: 10 | Print train loss every N steps | None |
| tensorboard | Optional[DeepspeedTensorboardConfig], default: None | Enables and configures tensorboard support | None |
| verbose | bool, default: True | flag to make deepspeed engine verbose with information | None |
| wall_clock_breakdown | bool, default: False | Enable timing of the latency of forward/backward/update training phases | None |
| zero_optimization | Optional[DeepspeedZeROConfig], default: DeepspeedZeROConfig() | Enables and configures ZeRO memory optimizations | None |

??? example "View Source" class DeepspeedConfig:

        """Deepspeed configuration class



        Composed of other configuration classes related to specific functionality



        Attributes

        ----------

        activation_checkpointing: Optional[DeepspeedActivationCheckpointingConfig], default: DeepspeedActivationCheckpointingConfig()

            Enables and configures activation checkpointing

        aio: Optional[DeepspeedAIOConfig], default: DeepspeedAIOConfig()

            Configuring the asynchronous I/O module for offloading parameter and optimizer states to persistent

            (NVMe) storage

        auto_mpi_discovery: bool, default: True

            if distributed environment variables are not set, attempt to discover them from MPI

        disable_allgather: bool, default: False

            Disables allgather

        dist_backend: BackendOptions, default: 'nccl'

            Which communication backend to use

        distributed_port: int, default: 29500

            torch distributed backend port

        dump_state: bool, default: False

            Print out state information of DeepSpeed object after initialization

        flops_profiler: Optional[DeepspeedFlopsConfig], default: None

            Enables and configures the flops profiler. This would also enable wall_clock_breakdown

        fp16: Optional[DeepspeedFP16Config], default: None

            Enables and configures mixed precision/FP16 training that leverages NVIDIA’s Apex package

        fp32_allreduce: bool, default: False

            During gradient averaging perform allreduce with 32 bit values

        gradient_predivide_factor: float, default: 1.0

            Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability

            when scaling to large numbers of GPUs

        init_method: str, default: 'env://'

            URL specifying how to initialize the process group

        prescale_gradients: bool, default: False

            Scale gradients before doing allreduce

        progressive_layer_drop: Optional[DeepspeedPLDConfig], default: None

            Enables and configures progressive layer dropping

        sparse_gradients: bool, default: False

            Enable sparse compression of torch.nn.Embedding gradients

        steps_per_print: int, default: 10

            Print train loss every N steps

        tensorboard: Optional[DeepspeedTensorboardConfig], default: None

            Enables and configures tensorboard support

        verbose: bool, default: True

            flag to make deepspeed engine verbose with information

        wall_clock_breakdown: bool, default: False

            Enable timing of the latency of forward/backward/update training phases

        zero_optimization: Optional[DeepspeedZeROConfig], default: DeepspeedZeROConfig()

            Enables and configures ZeRO memory optimizations



        Notes

        -----

        Deepspeed does not use Apex’s AMP mode which allows for more flexibility in mixed precision training modes. FP16

        here is similar to AMP’s O2 mode



        """



        activation_checkpointing: Optional[

            DeepspeedActivationCheckpointingConfig

        ] = DeepspeedActivationCheckpointingConfig()

        aio: Optional[DeepspeedAIOConfig] = DeepspeedAIOConfig()

        auto_mpi_discovery: bool = True

        disable_allgather: bool = False

        dist_backend: BackendOptions = "nccl"

        distributed_port: int = 29500

        dump_state: bool = False

        flops_profiler: Optional[DeepspeedFlopsConfig] = None

        fp16: Optional[DeepspeedFP16Config] = None

        fp32_allreduce: bool = False

        gradient_predivide_factor: float = 1.0

        init_method: str = "env://"

        prescale_gradients: bool = False

        progressive_layer_drop: Optional[DeepspeedPLDConfig] = None

        sparse_gradients: bool = False

        steps_per_print: int = 10

        tensorboard: Optional[DeepspeedTensorboardConfig] = None

        verbose: bool = True

        wall_clock_breakdown: bool = False

        zero_optimization: Optional[DeepspeedZeROConfig] = DeepspeedZeROConfig()
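
Since DeepspeedConfig is composed of the other Deepspeed* classes on this page, a configuration is assembled by nesting them. A hedged sketch of ZeRO stage 2 with dynamic FP16 loss scaling (all values illustrative, not recommendations):

    from stoke import DeepspeedConfig, DeepspeedFP16Config, DeepspeedZeROConfig

    deepspeed_config = DeepspeedConfig(
        fp16=DeepspeedFP16Config(loss_scale=0.0),  # 0.0 --> dynamic loss scaling
        zero_optimization=DeepspeedZeROConfig(
            stage=2,
            contiguous_gradients=True,
            overlap_comm=True,
        ),
        wall_clock_breakdown=True,
        steps_per_print=100,
    )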

DeepspeedFP16Config

class DeepspeedFP16Config(
    hysteresis: int = 2,
    initial_scale_power: int = 32,
    loss_scale: float = 0.0,
    loss_scale_window: int = 1000,
    min_loss_scale: int = 1000
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hysteresis | int, default: 2 | represents the delay shift in dynamic loss scaling | None |
| initial_scale_power | int, default: 32 | power of the initial dynamic loss scale value. The actual loss scale is computed as 2 ** initial_scale_power | None |
| loss_scale | float, default: 0.0 | loss scaling value for FP16 training (0.0 --> dynamic scaling) | None |
| loss_scale_window | int, default: 1000 | the window over which to raise/lower the dynamic loss scale value | None |
| min_loss_scale | int, default: 1000 | minimum dynamic loss scale value | None |

??? example "View Source" class DeepspeedFP16Config:

        """Deepspeed FP16 configuration class



        Attributes

        ----------

        hysteresis: int, default: 2

            represents the delay shift in dynamic loss scaling

        initial_scale_power: int, default: 32

            power of the initial dynamic loss scale value. The actual loss scale is computed as 2 ** initial_scale_power

        loss_scale: float, default: 0.0

            loss scaling value for FP16 training (0.0 --> dynamic scaling)

        loss_scale_window: int, default: 1000

            the window over which to raise/lower the dynamic loss scale value

        min_loss_scale: int, default: 1000

            minimum dynamic loss scale value



        """



        hysteresis: int = 2

        initial_scale_power: int = 32

        loss_scale: float = 0.0

        loss_scale_window: int = 1000

        min_loss_scale: int = 1000

DeepspeedFlopsConfig

class DeepspeedFlopsConfig(
    detailed: bool = True,
    module_depth: int = -1,
    output_file: Union[str, NoneType] = None,
    profile_step: int = 1,
    top_modules: int = 1
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| detailed | bool, default: True | Whether to print the detailed model profile | None |
| module_depth | int, default: -1 | The depth of the model at which to print the aggregated module information. When set to -1, it prints information from the top module to the innermost modules (the maximum depth). | None |
| output_file | Optional[str], default: None | Path to the output file. If None, the profiler prints to stdout | None |
| profile_step | int, default: 1 | The global training step at which to profile. | None |
| top_modules | int, default: 1 | Limits the aggregated profile output to the number of top modules specified. | None |

??? example "View Source" class DeepspeedFlopsConfig:

        """Deepspeed flops profiler configuration class



        Attributes

        ----------

        detailed: bool, default: True

            Whether to print the detailed model profile

        module_depth: int, default: -1

            The depth of the model at which to print the aggregated module information. When set to -1, it prints

            information from the top module to the innermost modules (the maximum depth).

        output_file: Optional[str], default: None

            Path to the output file. If None, the profiler prints to stdout

        profile_step: int, default: 1

            The global training step at which to profile.

        top_modules: int, default: 1

            Limits the aggregated profile output to the number of top modules specified.



        Notes

        -----

        Warm up steps are needed for accurate time measurement



        """



        detailed: bool = True

        module_depth: int = -1

        output_file: Optional[str] = None

        profile_step: int = 1

        top_modules: int = 1

DeepspeedOffloadOptimizerConfig

class DeepspeedOffloadOptimizerConfig(
    buffer_count: int = 4,
    device: stoke.configs.OffloadDevice = 'cpu',
    fast_init: bool = False,
    nvme_path: str = '/local_nvme',
    pin_memory: bool = False,
    pipeline: bool = False,
    pipeline_read: bool = False,
    pipeline_write: bool = False
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| buffer_count | int, default: 4 | Number of buffers in buffer pool for optimizer state offloading to NVMe. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). | None |
| device | OffloadDevice, default: 'cpu' | Device memory to offload optimizer state | None |
| fast_init | bool, default: False | Enable fast optimizer initialization when offloading to NVMe | None |
| nvme_path | str, default: '/local_nvme' | Filesystem path for NVMe device for optimizer state offloading | None |
| pin_memory | bool, default: False | Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | None |
| pipeline | bool, default: False | pipeline activated (will default to True if either pipeline_read or pipeline_write is set to True) | None |
| pipeline_read | bool, default: False | activate pipeline read (deepspeed has limited docs for what this does) | None |
| pipeline_write | bool, default: False | activate pipeline write (deepspeed has limited docs for what this does) | None |

??? example "View Source" class DeepspeedOffloadOptimizerConfig:

        """Deepspeed optimizer offloading configuration class



        Attributes

        ----------

        buffer_count: int, default: 4

            Number of buffers in buffer pool for optimizer state offloading to NVMe. This should be at least the number

            of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter,

            gradient, momentum, and variance).

        device: OffloadDevice, default: 'cpu'

            Device memory to offload optimizer state

        fast_init: bool, default: False

            Enable fast optimizer initialization when offloading to NVMe

        nvme_path: str, default: '/local_nvme'

            Filesystem path for NVMe device for optimizer state offloading

        pin_memory: bool, default: False

            Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.

        pipeline: bool, default: False

            pipeline activated (will default to True if either pipeline_read or pipeline_write is set to True)

        pipeline_read: bool, default: False

            activate pipeline read (deepspeed has limited docs for what this does)

        pipeline_write: bool, default: False

            activate pipeline write (deepspeed has limited docs for what this does)



        """



        buffer_count: int = 4

        device: OffloadDevice = "cpu"

        fast_init: bool = False

        nvme_path: str = "/local_nvme"

        pin_memory: bool = False

        pipeline: bool = False

        pipeline_read: bool = False

        pipeline_write: bool = False

DeepspeedOffloadParamConfig

class DeepspeedOffloadParamConfig(
    buffer_count: int = 5,
    buffer_size: int = 100000000,
    device: stoke.configs.OffloadDevice = 'cpu',
    max_in_cpu: int = 1000000000,
    nvme_path: str = '/local_nvme',
    pin_memory: bool = False
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| buffer_count | int, default: 5 | Number of buffers in buffer pool for parameter offloading to NVMe | None |
| buffer_size | int, default: int(1E8) | Size of buffers in buffer pool for parameter offloading to NVMe | None |
| device | OffloadDevice, default: 'cpu' | Device memory to offload model parameters | None |
| max_in_cpu | int, default: int(1E9) | Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled. | None |
| nvme_path | str, default: '/local_nvme' | Filesystem path for NVMe device for parameter offloading | None |
| pin_memory | bool, default: False | Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | None |

??? example "View Source" class DeepspeedOffloadParamConfig:

        """Deepspeed parameter offloading configuration class



        Attributes

        ----------

        buffer_count: int, default: 5

            Number of buffers in buffer pool for parameter offloading to NVMe

        buffer_size: int, default: int(1E8)

            Size of buffers in buffer pool for parameter offloading to NVMe

        device: OffloadDevice, default: 'cpu'

            Device memory to offload model parameters

        max_in_cpu: int, default: int(1E9)

            Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled.

        nvme_path: str, default: '/local_nvme'

            Filesystem path for NVMe device for parameter offloading

        pin_memory: bool, default: False

            Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.



        """



        buffer_count: int = 5

        buffer_size: int = int(1e8)

        device: OffloadDevice = "cpu"

        max_in_cpu: int = int(1e9)

        nvme_path: str = "/local_nvme"

        pin_memory: bool = False
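
For illustration, the two offload configs are attached to a stage 3 DeepspeedZeROConfig (documented further below); a hedged sketch of offloading both parameters and optimizer state to pinned CPU memory:

    from stoke import (
        DeepspeedOffloadOptimizerConfig,
        DeepspeedOffloadParamConfig,
        DeepspeedZeROConfig,
    )

    # Per the DeepspeedZeROConfig docs, both offload options are valid only with stage 3
    zero_config = DeepspeedZeROConfig(
        stage=3,
        offload_optimizer=DeepspeedOffloadOptimizerConfig(device="cpu", pin_memory=True),
        offload_param=DeepspeedOffloadParamConfig(device="cpu", pin_memory=True),
    )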

DeepspeedPLDConfig

class DeepspeedPLDConfig(
    theta: float = 1.0,
    gamma: float = 0.001
)

Attributes

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| theta | float, default: 1.0 | Hyper-parameter that controls the trade-off between training time and robustness. The lower the theta value, the faster the training speed | None |
| gamma | float, default: 0.001 | Hyper-parameter that controls how fast the drop ratio increases | None |

??? example "View Source" class DeepspeedPLDConfig:

        """

        Attributes

        ----------

        theta: float, default: 1.0

            Hyper-parameter that controls the trade-off between training time and robustness. The lower the theta value,

            the faster the training speed

        gamma: float, default: 0.001

            Hyper-parameter that controls how fast the drop ratio increases



        """



        theta: float = 1.0

        gamma: float = 0.001

DeepspeedTensorboardConfig

class DeepspeedTensorboardConfig(
    output_path: str = '',
    job_name: str = 'DeepSpeedJobName'
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| output_path | str, default: '' | Tensorboard output path |
| job_name | str, default: 'DeepSpeedJobName' | Tensorboard job name |

??? example "View Source" class DeepspeedTensorboardConfig:

        """Deepspeed Tensorboard configuration class



        Attributes

        ----------

        output_path: str, default: ''

            Tensorboard output path

        job_name: str, default: 'DeepSpeedJobName'

            Tensorboard job name



        """



        output_path: str = ""

        job_name: str = "DeepSpeedJobName"
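
A trivial sketch of the Tensorboard config; the output directory and job name below are placeholders.

```python
from stoke import DeepspeedTensorboardConfig

tb_config = DeepspeedTensorboardConfig(
    output_path="/tmp/deepspeed_tb",  # hypothetical path, any writable directory works
    job_name="my_training_run",
)
```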

DeepspeedZeROConfig

class DeepspeedZeROConfig(
    allgather_bucket_size: int = 500000000,
    allgather_partitions: bool = True,
    contiguous_gradients: bool = False,
    ignore_unused_parameters: bool = True,
    legacy_stage1: bool = False,
    offload_optimizer: Union[stoke.configs.DeepspeedOffloadOptimizerConfig, NoneType] = None,
    offload_param: Union[stoke.configs.DeepspeedOffloadParamConfig, NoneType] = None,
    overlap_comm: bool = False,
    reduce_bucket_size: int = 500000000,
    reduce_scatter: bool = True,
    stage: int = 0,
    stage3_max_live_parameters: int = 1000000000,
    stage3_max_reuse_distance: int = 1000000000,
    stage3_prefetch_bucket_size: int = 500000000,
    stage3_param_persistence_threshold: int = 1000000,
    stage3_gather_fp16_weights_on_model_save: bool = False,
    sub_group_size: int = 1000000000000
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| allgather_bucket_size | int, default: int(5E8) | Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes |
| allgather_partitions | bool, default: True | Chooses between an allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step |
| contiguous_gradients | bool, default: False | Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during the backward pass. Only useful when running very large models |
| ignore_unused_parameters | bool, default: True | Currently only used in the stage 2 complete_grad_norm_calculation_for_cpu_offload path. Enable this option to avoid https://github.com/microsoft/DeepSpeed/issues/707 |
| legacy_stage1 | bool, default: False | Use DeepSpeed < v0.3.17 ZeRO stage 1, kept for backward compatibility reasons |
| offload_optimizer | Optional[DeepspeedOffloadOptimizerConfig], default: None | Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3 |
| offload_param | Optional[DeepspeedOffloadParamConfig], default: None | Enable offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3 |
| overlap_comm | bool, default: False | Attempts to overlap the reduction of the gradients with backward computation |
| reduce_bucket_size | int, default: int(5E8) | Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes |
| reduce_scatter | bool, default: True | Uses reduce or reduce-scatter instead of allreduce to average gradients |
| stage | int, default: 0 | Chooses different stages of the ZeRO optimizer. Stages 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively |
| stage3_max_live_parameters | int, default: int(1E9) | The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication |
| stage3_max_reuse_distance | int, default: int(1E9) | Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication |
| stage3_prefetch_bucket_size | int, default: int(5E8) | The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication |
| stage3_param_persistence_threshold | int, default: int(1E6) | Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages) |
| stage3_gather_fp16_weights_on_model_save | bool, default: False | Consolidate the weights before saving the model with save_fp16_model(). Since the weights are partitioned across GPUs, they aren't part of state_dict, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights |
| sub_group_size | int, default: int(1E12) | Controls the granularity in which parameters are updated during optimizer steps. Parameters are grouped into buckets of sub_group_size and each bucket is updated one at a time |

??? example "View Source" class DeepspeedZeROConfig:

        """Deepspeed ZeRO configuration class



        Attributes

        ----------

        allgather_bucket_size: int, default: int(5E8)

            Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes

        allgather_partitions: bool, default: True

            Chooses between allgather collective or a series of broadcast collectives to gather updated parameters

            from all the GPUs at the end of each step

        contiguous_gradients: bool, default: False

            Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward

            pass. Only useful when running very large models.

        ignore_unused_parameters: bool, default: True

            Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload

            Enable this option to avoid -- https://github.com/microsoft/DeepSpeed/issues/707

        legacy_stage1: bool, default: False

            Use deepspeed < v0.3.17 zero stage 1, kept for backward compatibility reasons

        offload_optimizer: Optional[DeepspeedOffloadOptimizerConfig], default: None

            Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU

            memory for larger models or batch sizes. Valid only with stage 3

        offload_param: Optional[DeepspeedOffloadParamConfig], default: None

            Enable offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch

            sizes. Valid only with stage 3.

        overlap_comm: bool, default: False

            Attempts to overlap the reduction of the gradients with backward computation

        reduce_bucket_size: int, default: int(5E8)

            Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large

            model sizes

        reduce_scatter: bool, default: True

            Uses reduce or reduce scatter instead of allreduce to average gradients

        stage: int, default: 0

            Chooses different stages of ZeRO Optimizer. Stage 0, 1, 2, and 3 refer to disabled, optimizer state

            partitioning, and optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning,

            respectively

        stage3_max_live_parameters: int, default: int(1E9)

            The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but

            perform more communication.

        stage3_max_reuse_distance: int, default: int(1E9)

            Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less

            memory, but perform more communication.

        stage3_prefetch_bucket_size: int, default: int(5E8)

            The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase

            stalls due to communication.

        stage3_param_persistence_threshold: int, default: int(1E6)

            Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly

            increase communication (especially latency-bound messages).

        stage3_gather_fp16_weights_on_model_save: bool, default: False

            Consolidate the weights before saving the model by save_fp16_model(). Since the weights are partitioned

            across GPUs, they aren’t part of state_dict, so this function automatically gather the weights when this

            option is enabled and then saves the fp16 model weights.

        sub_group_size: int, default: int(1E12)

            sub_group_size controls the granularity in which parameters are updated during optimizer steps. Parameters are

            grouped into buckets of sub_group_size and each buckets is updated one at a time.



        """



        allgather_bucket_size: int = int(5e8)

        allgather_partitions: bool = True

        contiguous_gradients: bool = False

        ignore_unused_parameters: bool = True

        legacy_stage1: bool = False

        offload_optimizer: Optional[DeepspeedOffloadOptimizerConfig] = None

        offload_param: Optional[DeepspeedOffloadParamConfig] = None

        overlap_comm: bool = False

        reduce_bucket_size: int = int(5e8)

        reduce_scatter: bool = True

        stage: int = 0

        stage3_max_live_parameters: int = int(1e9)

        stage3_max_reuse_distance: int = int(1e9)

        stage3_prefetch_bucket_size: int = int(5e8)

        stage3_param_persistence_threshold: int = int(1e6)

        stage3_gather_fp16_weights_on_model_save: bool = False

        sub_group_size: int = int(1e12)
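
A hedged sketch of the two most common ZeRO setups described in the table above: stage 2 with communication overlap, and stage 3 with optimizer-state offload. All values are illustrative, not tuned recommendations.

```python
from stoke import DeepspeedOffloadOptimizerConfig, DeepspeedZeROConfig

# Stage 2: optimizer + gradient state partitioning, overlapping gradient
# reduction with the backward pass
zero_stage2 = DeepspeedZeROConfig(
    stage=2,
    overlap_comm=True,
    contiguous_gradients=True,
    reduce_bucket_size=int(5e8),
)

# Stage 3: parameter partitioning as well; the offload_* configs are only valid here
zero_stage3 = DeepspeedZeROConfig(
    stage=3,
    offload_optimizer=DeepspeedOffloadOptimizerConfig(device="cpu", pin_memory=True),
    stage3_gather_fp16_weights_on_model_save=True,
)
```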

DistributedOptions

class DistributedOptions(
    /,
    *args,
    **kwargs
)

??? example "View Source" class DistributedOptions(Enum):

        """Enum that defines the options for Distributed backends"""



        horovod = "horovod"

        ddp = "ddp"

        deepspeed = "deepspeed"

Ancestors (in MRO)

  • enum.Enum

Class variables

ddp
deepspeed
horovod
name
value

FP16Options

class FP16Options(
    /,
    *args,
    **kwargs
)

??? example "View Source" class FP16Options(Enum):

        """Enum that defines the options for FP16 backends"""



        apex_O1 = "apex_O1"

        apex_O2 = "apex_O2"

        amp = "amp"

        deepspeed = "deepspeed"

Ancestors (in MRO)

  • enum.Enum

Class variables

amp
apex_O1
apex_O2
deepspeed
name
value
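
Both enums map directly onto the `distributed` and `fp16` arguments of the Stoke constructor shown later in this reference. A minimal sketch of picking a runtime combination:

```python
from stoke import DistributedOptions, FP16Options

# Handed to Stoke(distributed=..., fp16=...); the pairing below is just an example
distributed_choice = DistributedOptions.ddp  # "ddp", "horovod", or "deepspeed"
fp16_choice = FP16Options.amp                # "amp", "apex_O1", "apex_O2", or "deepspeed"
```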

FairscaleFSDPConfig

class FairscaleFSDPConfig(
    bucket_cap_mb: int = 25,
    buffer_dtype: Union[torch.dtype, NoneType] = None,
    clear_autocast_cache: bool = False,
    compute_dtype: Union[torch.dtype, NoneType] = None,
    flatten_parameters: bool = True,
    force_input_to_fp32: bool = False,
    fp32_reduce_scatter: bool = False,
    gradient_predivide_factor: Union[float, NoneType] = None,
    gradient_postdivide_factor: Union[float, NoneType] = None,
    move_grads_to_cpu: Union[bool, NoneType] = None,
    move_params_to_cpu: bool = False,
    no_broadcast_optim_state: Union[bool, NoneType] = False,
    reshard_after_forward: bool = True,
    verbose: bool = False
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| bucket_cap_mb | int, default: 25 | FSDP will bucket parameters so that gradient reduction can be more efficient for small parameters. bucket_cap_mb controls the bucket size in megabytes (MB). Buckets are sub-divided based on world_size, so the max shard size is roughly bucket_cap_mb / world_size. There is one bucketer (with potentially multiple bucket_cap_mb-sized buffers) shared by all FSDP instances. Large gradient tensors are directly reduced without using the buffers; the buffers are there to reduce communication overhead for small tensors. Overlapping with computation happens due to use of a different CUDA stream than the computation CUDA stream. The total memory overhead per buffer is around bucket_cap_mb / world_size * (world_size + 1). The buffers are allocated during the backward pass and freed at the end of the backward pass to save more memory for other phases of the training process. Note that the memory vs. speed tradeoff of bucket size is very different from that of the DDP engine. In DDP, the buffer size is 1MB + n*cap_mb, until n is big enough to cover the entire model size; the order in which buffers become ready is more rigid there, and DDP requires all gradients to be computed in the backward pass. In FSDP, the buffer size does not change with model size (it changes based on the number of `<dtype, device, process_group>` tuples), and gradient ready order matters little since FSDP has a final flush call that ensures everything is reduced and not all gradients need to be known upfront. Overlapping with compute is done differently too. Values <= 0 disable bucketing |
| buffer_dtype | Optional[torch.dtype], default: None | dtype for buffers for computation; defaults to the value of compute_dtype |
| clear_autocast_cache | bool, default: False | When using mixed precision training with FP16 AMP, if the model weights are in FP32, autocast maintains a cache for downcast weights. The cache can cause GPU OOM during the forward pass. Setting this flag to true helps clear this cache as inner FSDP instances finish part of the forward pass, to save GPU memory |
| compute_dtype | Optional[torch.dtype], default: None | dtype for full parameters for computation. This defaults to torch.float32 unless FP16 AMP is set, in which case it defaults to torch.float16 |
| flatten_parameters | bool, default: True | Flatten parameters into a single contiguous tensor, which improves training speed |
| force_input_to_fp32 | bool, default: False | Force input floating point tensors to be FP32 (if they are FP16) when the FSDP instance is in full precision mode. This helps avoid issues of running SyncBatchNorm with AMP and checkpoint_wrapper |
| fp32_reduce_scatter | bool, default: False | Reduce-scatter gradients in FP32. This is only relevant when FP16 AMP is used |
| gradient_predivide_factor | Optional[float], default: None | Divide factor applied before the reduction |
| gradient_postdivide_factor | Optional[float], default: None | Divide factor applied after the reduction |
| move_grads_to_cpu | Optional[bool], default: None | Move gradient shard to CPU after reduction. This is only relevant when FP16 AMP is used |
| move_params_to_cpu | bool, default: False | Offload FP32 params to CPU. This is only relevant when FP16 AMP is used |
| no_broadcast_optim_state | Optional[bool], default: False | Do not broadcast this module's optimizer state when gather_full_optim_state_dict is called. If you set this to true, you are expected to overwrite the relevant state entries of the returned optimizer state dict with the proper state at each rank. This is useful for situations, like Mixture of Experts, where all but a few parameters can fit on one node |
| reshard_after_forward | bool, default: True | Reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when resharding individual layers (see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) |
| verbose | bool, default: False | Turn on verbose output for the model's string representation |

??? example "View Source" class FairscaleFSDPConfig:

        """Fairscale Fully Sharded Data Parallel configuration class



        Attributes

        ----------

        bucket_cap_mb: int, default: 25

            FSDP will bucket parameters so that gradient reduction can be more efficient for small parameters.

            bucket_cap_mb controls the bucket size in MegaBytes (MB). Buckets are sub-divided based on world_size, so the

            max shard size is roughly bucket_cap_mb / world_size. There is one bucketer (with potentially multiple

            bucket_cap_mb sized buffers shared by all FSDP instances. Large gradient tensors are directly reduced without

            using the buffers. The buffers are there to reduce communication overhead for small tensors. Overlapping with

            computation happens due to use of a different CUDA stream than the computation CUDA stream. The total memory

            overhead per buffer is around bucket_cap_mb / world_size * (world_size + 1). The buffers are allocated during

            the backward pass and freed at the end of the backward pass to save more memory for other phases of the

            training process. Note, the memory vs. speed tradeoff of bucket size is very different from that of the DDP

            engine. In DDP, the buffer size 1MB + n*cap_mb, until n is big enough to cover the entire model size. The

            order of which buffer is ready there is more rigid and DDP requires all gradients to be computed in the

            backward. In FSDP, the buffer size does not change with model size (it changes based on number of

            <dtype, device, process_group> tuples) and gradient ready order matters little since FSDP has a final flush

            call that ensures everything is reduced and not all gradients need to be upfront known. Overlapping with

            compute is done differently too. Values <= 0 disable bucketing

        buffer_dtype: Optional[torch.dtype], default: None

            dtype for buffers for computation. defaults to value of compute_dtype

        clear_autocast_cache: bool, default: False

            When using mixed precision training with FP16 AMP, if the model weights are in FP32, autocast

            maintains a cache for downcasted weights. The cache can cause GPU OOM during the forward pass. Setting this

            flag to true will help clearing this cache as inner FSDP instances finish part of the forward pass to save

            GPU memory

        compute_dtype: Optional[torch.dtype], default: None

            dtype for full parameters for computation. This defaults to torch.float32 unless FP 16 AMP is set,

            in which case it defaults to torch.float16.

        flatten_parameters: bool, default: True

            flatten parameters into a single contiguous tensor, which improves training speed

        force_input_to_fp32: bool, default: False:

            force input floating point tensors to be FP32 (if they are FP16) when the FSDP instance is in full precision

            mode. This helps avoid issues of running SyncBatchNorm with AMP and checkpoint_wrapper.

        fp32_reduce_scatter: bool, default: False

            reduce-scatter gradients in FP32. This is only relevant when FP16 AMP is used

        gradient_predivide_factor: Optional[float], default: None

            divide factor before the reduction

        gradient_postdivide_factor: Optional[float], default: None

            divide factor after the reduction

        move_grads_to_cpu: Optional[bool], default: None

            move gradient shard to CPU after reduction. This is only relevant when FP16 AMP is used

        move_params_to_cpu: bool, default: False

            offload FP32 params to CPU. This is only relevant when FP16 AMP is used

        no_broadcast_optim_state: Optional[bool], default: False

            do not broadcast this modules optimizer state when gather_full_optim_state_dict is called. If you set this

            true, you are expected to overwrite the relevant state entries of the returned optimizer state dict with the

            proper state at each rank. This is useful for situations, like Mixture Of Experts, where all but a few

            parameters can fit on one node

        reshard_after_forward: bool, default: True

            reshard parameters after the forward pass. This saves memory but slows training. This is only relevant

            when resharding individual layers (see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html)

        verbose: bool, default: False

            turn on verbose output for model’s string representation



        Notes

        -----

        mixed_precision: bool

            This value will automatically be set from the Stoke FP16 selected option (AMP only)

        state_dict_device: torch.device

            this is not exposed as it should be managed internally from the DDP backend setup

        compute_device: torch.device

            this is not exposed as it should be managed internally from the DDP backend setup



        """



        bucket_cap_mb: int = 25

        buffer_dtype: Optional[torch.dtype] = None

        clear_autocast_cache: bool = False

        compute_dtype: Optional[torch.dtype] = None

        flatten_parameters: bool = True

        force_input_to_fp32: bool = False

        fp32_reduce_scatter: bool = False

        gradient_predivide_factor: Optional[float] = None

        gradient_postdivide_factor: Optional[float] = None

        move_grads_to_cpu: Optional[bool] = None

        move_params_to_cpu: bool = False

        no_broadcast_optim_state: Optional[bool] = False

        reshard_after_forward: bool = True

        verbose: bool = False

Descendants

  • stoke.extensions._FairscaleFSDPConfig
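
A hedged configuration sketch for fully sharded data parallel; the flags are the ones documented above and the values are illustrative rather than recommendations. The config is activated by passing `fairscale_fsdp=True` to Stoke together with this object in its `configs` list (see the Stoke class below).

```python
from stoke import FairscaleFSDPConfig

fsdp_config = FairscaleFSDPConfig(
    bucket_cap_mb=25,            # gradient reduction bucket size in MB; <= 0 disables bucketing
    flatten_parameters=True,     # keep the flattened-parameter fast path
    reshard_after_forward=True,  # trade speed for memory when resharding individual layers
    move_grads_to_cpu=False,     # only relevant when FP16 AMP is used
)
```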

FairscaleOSSConfig

class FairscaleOSSConfig(
    broadcast_fp16: bool = False
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| broadcast_fp16 | bool, default: False | Compress the model shards in fp16 before sharing them between ranks. This is safe to use when PyTorch AMP is activated; without torch AMP this will lead to a slight degradation in accuracy |

??? example "View Source" class FairscaleOSSConfig:

        """Fairscale optimizer state sharding configuration class



        Attributes

        ----------

        broadcast_fp16: bool, default: False

            Compress the model shards in fp16 before sharing them in between ranks. This is safe to use when PyTorch AMP

            is activated. Without torch AMP this will lead to a slight degradation in terms of accuracy.



        """



        broadcast_fp16: bool = False
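
A short sketch of turning on optimizer state sharding; per the note above, fp16 shard broadcast is safe when PyTorch AMP is active. The Stoke call in the comment is shorthand for the constructor documented below.

```python
from stoke import FairscaleOSSConfig

oss_config = FairscaleOSSConfig(broadcast_fp16=True)
# Later: Stoke(..., fairscale_oss=True, configs=[oss_config, ...])
```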

FairscaleSDDPConfig

class FairscaleSDDPConfig(
    auto_refresh_trainable: bool = True,
    broadcast_buffers: bool = True,
    reduce_buffer_size: int = 8388608,
    reduce_fp16: bool = False,
    sync_models_at_startup: bool = True
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| auto_refresh_trainable | bool, default: True | Check whether the trainability of the parameters (requires_grad) has changed and update both ShardedDDP and OSS automatically if this is the case. If set to False, refresh_trainable() needs to be called anytime a parameter is frozen or unfrozen |
| broadcast_buffers | bool, default: True | Whether to additionally broadcast model buffers between ranks at the beginning of each forward pass. Same setting as in PyTorch DDP; this is in addition to the broadcast and reduction of the model parameters |
| reduce_buffer_size | int, default: 2 ** 23 | The max size of the buffer used to batch the small parameter tensors, in number of elements. This will impact long-term memory consumption, because these buckets correspond to parameters which will not be sharded. Set to 0 to remove all bucketing; 1M to 8M is usually reasonable |
| reduce_fp16 | bool, default: False | Cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve performance for multi-node jobs using PyTorch AMP. The effect is similar to DDP's fp16_compress_hook and will also save some memory |
| sync_models_at_startup | bool, default: True | Synchronize the models between the ranks when starting up. Not needed if each rank has the same seed, or if the training restarts from a saved state |

??? example "View Source" class FairscaleSDDPConfig:

        """Fairscale sharded data parallel (SDDP) configuration class



        Attributes

        ----------

        auto_refresh_trainable: bool, default: True

            Check whether the parameters trainability (requires_grad) has changed and update both ShardedDDP and OSS

            automatically if this is the case. If set to False, refresh_trainable() needs to be called anytime a

            parameter is frozen or unfrozen

        broadcast_buffers: bool, default: True

            Whether to additionally broadcast model buffers in between ranks at the beginning of each forward pass. Same

            setting as in Pytorch DDP, this is in addition to the broadcast and reduction of the model parameters.

        reduce_buffer_size: int, default: 2 ** 23

            The max size of the buffer used to batch the small parameter tensors, in number of elements. This will impact

            the long term memory consumption, because these buckets correspond to parameters which will not be sharded.

            Set to 0 to remove all bucketing, 1M to 8M is usually reasonable.

        reduce_fp16: bool, default: False

            cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve

            performance for multi node jobs using PyTorch AMP. The effect is similar to DDP’s fp16_compress_hook and

            will also save some memory.

        sync_models_at_startup: bool, default: True

            Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed, or

            the training restarts from a saved state



        """



        auto_refresh_trainable: bool = True

        broadcast_buffers: bool = True

        reduce_buffer_size: int = 2 ** 23

        reduce_fp16: bool = False

        sync_models_at_startup: bool = True
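
A hedged sketch of tuning the SDDP reduce buffer discussed above; 2 ** 23 elements is the default, and shrinking it lowers persistent memory at the cost of less batching of small reductions.

```python
from stoke import FairscaleSDDPConfig

sddp_config = FairscaleSDDPConfig(
    reduce_buffer_size=2 ** 22,   # ~4M elements; set to 0 to disable bucketing entirely
    sync_models_at_startup=True,  # keep True unless every rank is seeded identically
    auto_refresh_trainable=True,
)
```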

HorovodConfig

class HorovodConfig(
    compression: bool = False,
    convert_to_sync_batch_norm: bool = False,
    gradient_predivide_factor: float = 1.0,
    op: stoke.configs.HorovodOps = 'Average'
)

Attributes

| Name | Type | Description |
|------|------|-------------|
| compression | bool, default: False | Compression algorithm used during allreduce to reduce the amount of data sent during each parameter update step |
| convert_to_sync_batch_norm | bool, default: False | Automatically convert all batch norm calls to horovod.torch.SyncBatchNorm calls (https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.SyncBatchNorm) |
| gradient_predivide_factor | float, default: 1.0 | If op == Average, gradient_predivide_factor splits the averaging before and after the sum. Gradients are scaled by 1.0 / gradient_predivide_factor before the sum and gradient_predivide_factor / size after the sum |
| op | HorovodOps, default: 'Average' | The reduction operation to use when combining gradients across different ranks |

??? example "View Source" class HorovodConfig:

        """Horovod configuration class



        Attributes

        ----------

        compression: bool, default: False

            Compression algorithm used during allreduce to reduce the amount of data sent during each parameter

            update step.

        convert_to_sync_batch_norm: bool, default: False

            Automatically convert all batch norm calls to horovod.torch.SyncBatchNorm calls

            https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.SyncBatchNorm

        gradient_predivide_factor: float, default: 1.0

            If op == Average, gradient_predivide_factor splits the averaging before and after the sum. Gradients are scaled

            by 1.0 / gradient_predivide_factor before the sum and gradient_predivide_factor / size after the sum.

        op: HorovodOps, default: 'Average'

            The reduction operation to use when combining gradients across different ranks.



        """



        compression: bool = False

        convert_to_sync_batch_norm: bool = False

        gradient_predivide_factor: float = 1.0

        op: HorovodOps = "Average"
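
A minimal Horovod sketch using the attributes documented above; 'Average' is simply the documented default reduction op, and the flag values are illustrative.

```python
from stoke import HorovodConfig

horovod_config = HorovodConfig(
    compression=True,                 # enable allreduce compression (see description above)
    convert_to_sync_batch_norm=True,  # swap BatchNorm layers for horovod.torch.SyncBatchNorm
    gradient_predivide_factor=1.0,
    op="Average",
)
```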

ParamNormalize

class ParamNormalize(
    /,
    *args,
    **kwargs
)

??? example "View Source" class ParamNormalize(Enum):

        """Normalization enum for total number of model parameters used to help with a pretty print"""



        THOUSAND = 1e3

        MILLION = 1e6

        BILLION = 1e9

        TRILLION = 1e12

Ancestors (in MRO)

  • enum.Enum

Class variables

BILLION
MILLION
THOUSAND
TRILLION
name
value
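
ParamNormalize is only a scaling enum used for pretty-printing parameter counts. A trivial sketch of how the values can be used; the helper function below is hypothetical, only the enum itself comes from Stoke.

```python
from stoke import ParamNormalize

def pretty_param_count(n_params: int, scale: ParamNormalize = ParamNormalize.MILLION) -> str:
    # e.g. 12_345_678 parameters -> "12.35 MILLION"
    return f"{n_params / scale.value:.2f} {scale.name}"
```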

Stoke

class Stoke(
    model: torch.nn.modules.module.Module,
    optimizer: stoke.configs.StokeOptimizer,
    loss: Union[Callable, List[Callable], Tuple[Callable]],
    batch_size_per_device: int,
    grad_accum_steps: Union[int, NoneType] = 1,
    grad_clip: Union[stoke.configs.ClipGradConfig, stoke.configs.ClipGradNormConfig, NoneType] = None,
    gpu: bool = False,
    fp16: Union[stoke.status.FP16Options, NoneType] = None,
    distributed: Union[stoke.status.DistributedOptions, NoneType] = None,
    fairscale_oss: bool = False,
    fairscale_sddp: bool = False,
    fairscale_fsdp: bool = False,
    configs: Union[List[Union[stoke.configs.AMPConfig, stoke.configs.ApexConfig, stoke.configs.DDPConfig, stoke.configs.DeepspeedConfig, stoke.configs.FairscaleOSSConfig, stoke.configs.FairscaleSDDPConfig, stoke.configs.FairscaleFSDPConfig, stoke.configs.HorovodConfig]], NoneType] = None,
    info_rank: Union[int, List[int], NoneType] = 0,
    verbose: bool = True,
    ema_weight: float = 0.1
)

Attributes

The class exposes the following properties (no individual descriptions are generated for them): amp_config, apex_config, batch_size, cuda, ddp_config, deepspeed_config, distributed, effective_batch_size, ema_loss, fp16, fsdp_config, fully_sharded, gpu, grad_accum, grad_clip, horovod_config, is_amp, is_apex, is_ddp, is_deepspeed, is_horovod, loss_access, model_access, nccl, num_model_parameters, optimizer, oss, oss_config, rank, scaler, sddp_config, sharded, status, world_size

| Name | Type | Description |
|------|------|-------------|
| _agg_loss | Union[float, List[float], Tuple[float]] | Aggregated loss for grad accumulation (single or multiple losses) |
| _backward_steps | int | Number of times gradients have been calculated on a batch of samples (calls to backward) |
| _grad_accum_counter | int | Counter for grad accumulation steps |
| _loss | Union[Callable, List[Callable], Tuple[Callable]] | Callable function that calculates a loss from the model outputs |
| _last_step_loss | list, tuple, or float | Last loss step calculation aggregated over device(s) |
| _model | torch.nn.Module | Instance of torch.nn.Module for Stoke to handle |
| _optimizer | StokeOptimizer | StokeOptimizer config object that describes the torch.optim.Optimizer and its kwargs |
| _optimizer_steps | int | Number of times step has been called on the optimizer |
| _runner | StokeRunner | The dynamically created runtime object that handles all ops |
| _status | StokeStatus | StokeStatus object that sets and maintains the current configuration |
| _verbose | bool | Print verbosity |
| _rolling_loss_steps | int | Number of steps that have been called for the rolling loss |
| _rolling_mean_loss | list, tuple, or float | Current EMA loss |
| _ema_weight | float | Weight used for any EMA calculation on metrics |
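
Pulling the pieces above together, a hedged end-to-end sketch of constructing a Stoke object. The two-layer model, loss, and hyper-parameters are placeholders, and StokeOptimizer is used here as the dict-like spec (keys `optimizer` and `optimizer_kwargs`) consumed in the `__init__` source below.

```python
import torch
from stoke import AMPConfig, FP16Options, Stoke, StokeOptimizer

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)

# Optimizer construction is deferred to Stoke so it can be wrapped correctly
opt = StokeOptimizer(optimizer=torch.optim.AdamW, optimizer_kwargs={"lr": 1e-3})

stoke_obj = Stoke(
    model=model,
    optimizer=opt,
    loss=torch.nn.MSELoss(),
    batch_size_per_device=32,
    grad_accum_steps=4,
    gpu=True,               # single-GPU here; add distributed=DistributedOptions.ddp etc. as needed
    fp16=FP16Options.amp,
    configs=[AMPConfig()],  # backend-specific configs are optional; defaults apply otherwise
)
# Training then runs through the wrapped ops (model, loss, backward, step)
# described in the class docstring below.
```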

??? example "View Source" class Stoke:

        """High level stoke object that manages all necessary configs and provides a unified interface to ops



        This is the main class within Stoke. Functionally it manages all interfaces to the necessary wrapped ops (model,

        loss, backward, step), provides helper functions, and dynamically constructs the runtime that handles the

        combinatorics problem of underlying frameworks (DDP, Horovod, Deepspeed, Fairscale),

        mixed-precision (AMP or APEX) and devices (CPU or GPU)



        Attributes

        ----------

        amp_config

        apex_config

        batch_size

        cuda

        ddp_config

        deepspeed_config

        distributed

        effective_batch_size

        ema_loss

        fp16

        fsdp_config

        fully_sharded

        gpu

        grad_accum

        grad_clip

        horovod_config

        is_amp

        is_apex

        is_ddp

        is_deepspeed

        is_horovod

        loss_access

        model_access

        nccl

        num_model_parameters

        optimizer

        oss

        oss_config

        rank

        scaler

        sddp_config

        sharded

        status

        world_size

        _agg_loss: Union[float, List[float], Tuple[float]]

            aggregated loss for grad accumulation (single or multiple losses)

        _backward_steps: int

            Number of times gradients have been calculated on a batch of samples (calls to backward)

        _grad_accum_counter: int

            counter for grad accumulation steps

        _loss: Union[Callable, List[Callable], Tuple[Callable]]

            callable function that calculates a loss from the model outputs

        _last_step_loss: list, tuple, or float

            last loss step calculation aggregated over device(s)

        _model: torch.nn.Module

            instance of torch.nn.Module for Stoke to handle

        _optimizer: StokeOptimizer

            StokeOptimizer config object that describes the torch.optim.Optimizer and it's kwargs

        _optimizer_steps: int

            Number of times step has been called on the optimizer

        _runner: StokeRunner

            the dynamically created runtime object that handles all ops

        _status: StokeStatus

            StokeStatus object that sets and maintains the current configuration

        _verbose: bool

            print verbosity

        _rolling_loss_steps: int

            number of steps that have been called for the rolling loss

        _rolling_mean_loss: list, tuple, or float

            current ema loss

        _ema_weight: float

            weight used for any ema calculation on metrics



        """



        def __init__(

            self,

            model: torch.nn.Module,

            optimizer: StokeOptimizer,

            loss: Union[Callable, List[Callable], Tuple[Callable]],

            batch_size_per_device: int,

            grad_accum_steps: Optional[int] = 1,

            grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]] = None,

            gpu: bool = False,

            fp16: Optional[FP16Options] = None,

            distributed: Optional[DistributedOptions] = None,

            fairscale_oss: bool = False,

            fairscale_sddp: bool = False,

            fairscale_fsdp: bool = False,

            configs: Optional[

                List[

                    Union[

                        AMPConfig,

                        ApexConfig,

                        DDPConfig,

                        DeepspeedConfig,

                        FairscaleOSSConfig,

                        FairscaleSDDPConfig,

                        FairscaleFSDPConfig,

                        HorovodConfig,

                    ]

                ]

            ] = None,

            info_rank: Optional[Union[int, List[int]]] = 0,

            verbose: bool = True,

            ema_weight: float = 0.1,

        ):

            """Init for Stoke class object



            Parameters

            ----------

            model: torch.nn.Module

                PyTorch model

            optimizer: StokeOptimizer

                Optimizer configuration

            loss: Union[Callable, List[Callable], Tuple[Callable]]

                Callable loss function or functions

            batch_size_per_device: int

                Batch size at the single device level

            grad_accum_steps: Optional[int], default: 1

                Number of gradient accumulation steps

            grad_clip: Optional[Union[ClipGradConfig, ClipGradNormConfig]], default: None

                Gradient clipping configuration

            gpu: bool, default: False

                flag to use GPU device(s)

            fp16: Optional[FP16Options], default: None

                Choice of mixed-precision backend

            distributed: Optional[DistributedOptions], default: None

                Choice of distributed backend

            fairscale_oss: bool, default: False

                Flag to activate optimizer state sharding using Fairscale

            fairscale_sddp: bool, default: False

                Flag to activate sharded DDP using Fairscale

            fairscale_fsdp: bool, default: False

                Flag to activate fully sharded DDP using Fairscale

            configs: Optional[List[Union[AMPConfig, ApexConfig, DDPConfig, DeepspeedConfig, FairscaleOSSConfig, FairscaleSDDPConfig, FairscaleFSDPConfig, HorovodConfig]], default: None

                Configuration objects for runtimes

            info_rank: Optional[Union[int, List[int]]], default = 0

                Constrain prints to specific devices

            verbose: bool, default: True

                Flag for verbosity

            ema_weight: float, default: 0.1

                weight used for any ema calculation on metrics



            """

            # Verbosity

            self._verbose = verbose

            # Info rank

            self._info_rank = info_rank

            # EMA

            self._ema_weight = ema_weight

            # Setup the StokeState

            self._status = StokeStatus(

                batch_size_per_device=batch_size_per_device,

                grad_accum=grad_accum_steps,

                grad_clip=grad_clip,

                gpu=gpu,

                fp16=fp16,

                distributed=distributed,

                fairscale_oss=fairscale_oss,

                fairscale_sddp=fairscale_sddp,

                fairscale_fsdp=fairscale_fsdp,

                configs=configs,

            )

            # Run some checks

            self._model = self._check_model(model)

            self._optimizer = self._check_optimizer(optimizer)

            self._loss = self._check_loss(loss)

            # Dynamically construct the StokeRunner from the StokeStatus

            self._runner, class_info = self._build_runner()

            # Setup distributed backend

            self._runner.setup_distributed()

            # Post here the runner will have the print_device function that is mapped to the self.print here

            # as it needs rank to be accessible before working

            if self._verbose:

                dev_id = (

                    self.rank

                    if (self.rank == "cpu" or self.rank == "gpu")

                    else self._info_rank

                )

                self.print(f"Printing verbose information on rank(s): {dev_id}")

                # Print the runner class info from the mixins

                self.print(class_info)

            # Possibly place model on GPU depending on StokeStatus -- before wrap calls

            self._place_model_on_gpu()

            # Handle the wrap ops in the correct order

            self._handle_ordered_wrap_ops(optimizer=optimizer)

            # Create some tracking vars

            self._grad_accum_counter = 0

            self._optimizer_steps = 0

            self._backward_steps = 0

            self._last_step_loss = self._set_loss_to_zero()

            self._agg_loss = self._set_loss_to_zero()

            self._rolling_mean_loss = self._set_loss_to_zero()

            self._rolling_loss_steps = 0

            # Set post-init status variables

            self._status.set_post_init_values(world_size=self.world_size)

            # Print the final configuration

            if self._verbose:

                self.print(msg=self._status)



        def _wrap_optimizer_then_model(self, optimizer: StokeOptimizer):

            """Handles wrapping of optimizer then the model



            This holds only for SDDP, Horovod, and APEX as these need to use an instantiated optimizer before wrapped

            methods are called



            Parameters

            ----------

            optimizer: StokeOptimizer

                Optimizer configuration



            Returns

            -------

            None



            """

            # Build the optimizer

            self._optimizer = self._runner.build_optimizer(

                optimizer=optimizer["optimizer"],

                optimizer_kwargs=optimizer["optimizer_kwargs"],

                model=self._model,

            )

            # Setup/Initialize FP16 backend -- in this case the optimizer is passed through

            self._runner.wrap_fp16(model=self._model, optimizer=self._optimizer)

            # Wrap with distributed backend -- in this case the optimizer is passed through

            self._model, self._optimizer = self._runner.wrap_distributed(

                model=self._model, grad_accum=self.grad_accum, optimizer=self._optimizer

            )



        def _wrap_model_then_optimizer(self, optimizer: StokeOptimizer):

            """Handles wrapping of model then optimizer



            Parameters

            ----------

            optimizer: StokeOptimizer

                Optimizer configuration



            Returns

            -------

            None



            """

            # Wrap with distributed backend -- in this case the optimizer is passed as None since it doesn't exist yet

            # don't use the return for the optimizer in this case

            self._model, _ = self._runner.wrap_distributed(

                model=self._model, grad_accum=self.grad_accum, optimizer=None

            )

            # Setup/Initialize FP16 backend -- in this case the optimizer is passed as None since it doesn't exist yet

            self._runner.wrap_fp16(model=self._model, optimizer=None)

            # Build the optimizer

            self._optimizer = self._runner.build_optimizer(

                optimizer=optimizer["optimizer"],

                optimizer_kwargs=optimizer["optimizer_kwargs"],

                model=self._model,

            )



        def _handle_ordered_wrap_ops(self, optimizer: StokeOptimizer):

            """Handles wrapping model, using FP16, and wrapping optimizer in the correct order depending on Stoke Status



            Parameters

            ----------

            optimizer: StokeOptimizer

                Optimizer configuration



            Returns

            -------

            None



            """

            # if SDDP + OSS, Horovod, and APEX then we need to make sure that the optimizer gets wrapped before the model

            # gets wrapped, all other models follow standard DDP paradigm (or their own DeepSpeed)

            if (self.sharded and self.oss) or self.is_apex or self.is_horovod:

                self._wrap_optimizer_then_model(optimizer=optimizer)

            else:

                self._wrap_model_then_optimizer(optimizer=optimizer)



        def _check_accum(self):

            """Checks if the current step is the last accumulation step



            Returns

            -------

            bool



            """

            return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == 0



        def _check_pre_accum(self):

            """Checks if we are at the pre-accumulate step



            Returns

            -------

            bool



            """

            return (self._grad_accum_counter + 1) % (self.grad_accum + 1) == self.grad_accum



        def _set_loss_to_zero(self):

            """Used to set a loss tracker to zero depending on the type



            Returns

            -------

            float or list or tuple of reset loss



            """

            return (

                type(self._loss)([0.0] * len(self._loss))

                if isinstance(self._loss, (list, tuple))

                else 0.0

            )



        def reset_ema(self):

            """Used to reset the current state of the rolling mean loss



            Returns

            -------

            None



            """

            self._rolling_mean_loss = self._set_loss_to_zero()

            self._rolling_loss_steps = 0



        def print_ema_loss(

            self, prepend_msg: str = "Current EMA Loss", single_line: bool = False

        ):

            """Prints the current ema loss synced across all devices



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            prepend_msg: str, default: "Current EMA Loss"

                message prepend to print

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            if isinstance(self._rolling_mean_loss, (list, tuple)):

                print_vals = [

                    f"{prepend_msg} {idx}: {val:.3f}"

                    for idx, val in enumerate(self._rolling_mean_loss)

                ]

                self.print(print_vals, single_line=single_line)

            else:

                self.print(f"{prepend_msg}: {self._rolling_mean_loss:.3f}")



        def print_mean_accumulated_synced_loss(

            self,

            prepend_msg: str = "Mean Accumulated & Synced Loss",

            pre_backwards: bool = True,

            single_line: bool = False,

        ):

            """Prints the mean accumulated and device synced loss only after the grad accumulation step



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            prepend_msg: str, default: "Mean Accumulated & Synced Loss"

                message prepend to print

            pre_backwards: bool, default: True

                if being called pre backward step

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            check_fn = self._check_pre_accum if pre_backwards else self._check_accum

            if check_fn():

                if isinstance(self._agg_loss, (list, tuple)):

                    print_vals = self._scale_agg_loss()

                    self.print(print_vals, single_line=single_line)

                else:

                    self.print(f"{prepend_msg}: {self._scale_agg_loss():.3f}")



        def _scale_agg_loss(self):

            """Scales the mean aggregated loss by  grad accum



            Returns

            -------

            scale_vals: list or float of mean aggregated loss



            """

            if isinstance(self._agg_loss, (list, tuple)):

                scale_vals = [

                    val / self.grad_accum for idx, val in enumerate(self._agg_loss)

                ]

            else:

                scale_vals = self._agg_loss / self.grad_accum

            return scale_vals



        def print_synced_loss(

            self,

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],

            prepend_msg: str = "Step Synced Loss",

            device=None,

            single_line: bool = False,

        ):

            """Prints a device synced loss at a single step



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                current loss(es) on the device

            prepend_msg: str, default: "Step Synced Loss"

                message prepend to print

            device: default: None

                specify the device to place the synced loss on (defaults to same device)

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            printable_loss = self.detach_and_sync_loss(loss, device)

            if isinstance(printable_loss, (list, tuple)):

                print_vals = [

                    f"{prepend_msg} {idx}: {val * self.grad_accum:.3f}"

                    for idx, val in enumerate(printable_loss)

                ]

                self.print(print_vals, single_line=single_line)

            else:

                self.print(msg=f"{prepend_msg}: {printable_loss * self.grad_accum:.3f}")



        def print_on_devices(

            self, msg: Union[str, List[str]], rank: Optional[Union[int, List[int]]] = 0

        ):

            """Wraps runner print interface for shorter semantics



            Parameters

            ----------

            msg: str

                message to print

            rank: Union[int, List[int]], default: 0

                which ranks to print on



            Returns

            -------

            None



            """

            self._runner.print_device(msg=msg, rank=rank)



        def print(self, msg: Union[str, List[str]], single_line: bool = False):

            """Wraps the runners print device and forces print on the _info_rank attribute(s)



            Parameters

            ----------

            msg: str

                message to print

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            self._runner.print_device(

                msg=msg, rank=self._info_rank, single_line=single_line

            )



        @staticmethod

        def _check_model(model: torch.nn.Module):

            """Verifies the type of the model



            Parameters

            ----------

            model: torch.nn.Module

                current torch model



            Returns

            -------

            None



            """

            # Check if the model is an nn.Module such that it has a forward method

            if not isinstance(model, torch.nn.Module):

                raise TypeError(

                    f"Stoke -- Model is not of type torch.nn.Module, currently {type(model)}"

                )

            return model



        @staticmethod

        def _check_optimizer(optimizer: StokeOptimizer):

            """Verifies the type of the optimizer



            Parameters

            ----------

            optimizer: StokeOptimizer

                Current optimizer configuration TypedDict (aka dict)



            Returns

            -------

            None



            """

            if not isinstance(optimizer, dict):

                raise TypeError(

                    f"Stoke -- Optimizer is not of type torch.optim.Optimizer, currently {type(optimizer)}"

                )

            return optimizer



        def _check_loss(self, loss: Union[Callable, List[Callable], Tuple[Callable]]):

            """Checks to make sure the loss function(s) is/are callable



            Parameters

            ----------

            loss: Union[Callable, List[Callable], Tuple[Callable]]

                Current callable loss(es)



            Returns

            -------

            None



            """

            if isinstance(loss, (list, tuple)):

                loss = [self._check_loss(val) for val in loss]

                return loss

            elif isinstance(loss, Callable):

                return loss

            else:

                raise TypeError(

                    f"Stoke -- Loss is not of type Callable, currently {type(loss)}"

                )



        def _place_model_on_gpu(self):

            """Automatically moves the model to GPU device(s)



            Returns

            -------

            None



            """

            if self.gpu and not self.is_deepspeed:

                if self._verbose:

                    self.print(f"Automatically handling moving model to GPU(s)...")

                self._model.cuda()



        def _build_runner(self):

            """Builds the runtime object from the mixin style classes



            Mixes the distributed class, fp16 class, and optimizer class into a single object such that all can be called

            from the same interface. Prevents verbose calls to multiple objects and unifies all functionality under a

            a single interface. Might prevent some IDE type-hinting as it's dynamic



            Returns

            -------

            StokeRunner

                runtime runner object



            """

            # Get the classes

            dist_class = self._get_distributed_mixin()

            fp16_class = self._get_fp16_mixin()

            optimizer_class = self._get_optimizer_mixin()

            io_class = self._get_io_mixin()



            # Python MRO hack to make sure the inits of all the Mixin classes get called

            def __multiple_mixin_init__(*args, **kwargs):

                dist_class.__init__(*args, **kwargs)

                fp16_class.__init__(*args, **kwargs)

                optimizer_class.__init__(*args, **kwargs)

                io_class.__init__(*args, **kwargs)



            # Configs pass through

            kwargs_dict = {

                "amp_config": self.amp_config,

                "apex_config": self.apex_config,

                "ddp_config": self.ddp_config,

                "deepspeed_config": self.deepspeed_config,

                "horovod_config": self.horovod_config,

                "oss_config": self.oss_config,

                "sharded_config": self.sddp_config,

                "fully_sharded_config": self.fsdp_config,

            }

            # Generate the runner class from the mixins based on the StokeStatus

            runner_class = type(

                "StokeRunner",

                (dist_class, fp16_class, optimizer_class, io_class),

                {"__init__": __multiple_mixin_init__},

            )(

                verbose=self._verbose,

                batch_size_per_device=self.batch_size,

                grad_accum_steps=self.grad_accum,

                grad_clip=self.grad_clip,

                info_rank=self._info_rank,

                loss=self._loss,

                **kwargs_dict,

            )

            # Make a list of class info for print later

            class_info = [

                f"Distributed Mixin: {dist_class.__name__}",

                f"Optimizer Mixin: {dist_class.__name__}",

                f"FP16 Mixin: {fp16_class.__name__}",

                f"IO Mixin: {io_class.__name__}",

            ]

            return runner_class, class_info



        def _get_io_mixin(self):

            """Determines which IO class to use



            Embedded logic based on the enum class



            Returns

            -------

            ABCMeta

                un-instantiated ioclass



            """

            if self.is_deepspeed:

                return_class = RunnerIOEnum.deepspeed.value

            elif self.is_horovod:

                return_class = RunnerIOEnum.horovod.value

            elif self.is_ddp:

                return_class = RunnerIOEnum.ddp.value

            else:

                return_class = RunnerIOEnum.base.value

            return return_class



        def _get_optimizer_mixin(self):

            """Determines which optimizer class to use



            Embedded logic based on the enum class



            Returns

            -------

            ABCMeta

                un-instantiated optimizer class



            """

            if self.oss:

                return_class = RunnerOptimizerEnum.oss.value

            else:

                return_class = RunnerOptimizerEnum.base.value

            return return_class



        def _get_distributed_mixin(self):

            """Determines which distributed class to use



            Embedded logic based on the enum class



            Returns

            -------

            ABCMeta

                un-instantiated distributed class



            """

            # if not gpu then fall to cpu single

            if not self.gpu:

                return_class = RunnerDistEnum.cpu.value

            # if gpu but no distributed then fall to single gpu

            elif self.gpu and (self.distributed is None):

                return_class = RunnerDistEnum.gpu.value

            elif self.gpu and (self.distributed is not None):

                return_class = RunnerDistEnum[self.distributed].value

            else:

                raise ValueError("Stoke -- Cannot map to a valid distributed class")

            return return_class



        def _get_fp16_mixin(self):

            """Determines which fp16 class to use



            Embedded logic based on the enum class



            Returns

            -------

            ABCMeta

                un-instantiated fp16 class



            """

            if self.fp16 is not None:

                return_class = RunnerFP16Enum[self.fp16].value

            else:

                return_class = RunnerFP16Enum.full.value

            return return_class



        def DataLoader(

            self,

            dataset: Dataset[T_co],

            shuffle: bool = False,

            sampler: Optional[Sampler[int]] = None,

            batch_sampler: Optional[Sampler[Sequence[int]]] = None,

            num_workers: int = 0,

            collate_fn: _collate_fn_t = None,

            pin_memory: bool = False,

            drop_last: bool = False,

            timeout: float = 0,

            worker_init_fn: _worker_init_fn_t = None,

            multiprocessing_context=None,

            generator=None,

            *,

            prefetch_factor: int = 2,

            persistent_workers: bool = False,

        ):

            """Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.



            Shim is necessary for two reasons... to inject some horovod runtime configs (make sure forkserver is called)

            and to automatically handle device placement since the gpu/fp16 flags can't be determined until the StokeStatus

            object is available which is post init. This could be disconnected from this class but it would require the

            user to forward on device or fp16 configs which breaks the paradigm that the flags only need to be set and

            never handled



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            shuffle: bool, default: False

                set to ``True`` to have the data reshuffled at every epoch.

            sampler: Sampler or Iterable, default: None

                defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``

                implemented. If specified, :attr:`shuffle` must not be specified.

            batch_sampler: Sampler or Iterable, default: None:

                like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with

                :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.

            num_workers: int, default: 0

                how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.

            collate_fn: callable, optional:

                merges a list of samples to form a mini-batch of Tensor(s).  Used when using batched loading from a

                map-style dataset.

            pin_memory: bool, default: False:

                If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your

                data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,

                see the example below.

            drop_last: bool, default: False

                set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.

                If ``False`` and the size of dataset is not divisible by the batch size, then the last batch

                will be smaller.

            timeout: numeric, default: 0

                if positive, the timeout value for collecting a batch from workers. Should always be non-negative.

            worker_init_fn: callable, default: None

                If not ``None``, this will be called on each worker subprocess with the worker id

                (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.

            prefetch_factor: int, default: 2

                Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers

                samples prefetched across all workers.

            persistent_workers: bool, default: False

                If ``True``, the data loader will not shutdown the worker processes after a dataset has been

                consumed once. This allows to maintain the workers `Dataset` instances alive.



            Returns

            -------

            StokeDataLoader

                wrapped torch.utils.data.DataLoader object



            """

            # Check if forkserver is available for horovod and use

            if (

                num_workers > 0

                and hasattr(torch.multiprocessing, "_supports_context")

                and torch.multiprocessing._supports_context

                and "forkserver" in torch.multiprocessing.get_all_start_methods()

                and self.is_horovod

            ):

                multiprocessing_context = "forkserver"



            if self._verbose and self.gpu:

                print(f"Automatically handling moving model input data to GPU(s)...")

            # Forward the already known options from the Stoke status

            return StokeDataLoader(

                gpu=self.gpu,

                fp16=self.fp16,

                batch_size=self.batch_size,

                dataset=dataset,

                shuffle=shuffle,

                sampler=sampler,

                batch_sampler=batch_sampler,

                num_workers=num_workers,

                collate_fn=collate_fn,

                pin_memory=pin_memory,

                drop_last=drop_last,

                timeout=timeout,

                worker_init_fn=worker_init_fn,

                multiprocessing_context=multiprocessing_context,

                generator=generator,

                prefetch_factor=prefetch_factor,

                persistent_workers=persistent_workers,

            )



        def model(self, *args, **kwargs):

            """Wrapped model forward call



            Parameters

            ----------

            *args: list or tuple

                Additional arguments should be passed as keyword arguments

            **kwargs: dict, optional

                Extra arguments passed to the model forward call



            Returns

            -------

            model forward output



            """

            with self._runner.model_context:

                return self._model(*args, **kwargs)

                # return self.model_access(*args, **kwargs)



        def loss(self, *args, **kwargs):

            """Wrapped callable loss function call



            Handles internal logic of aggregating up the losses for single and multiple losses



            Parameters

            ----------

            *args: list or tuple

                Additional arguments should be passed as keyword arguments

            **kwargs: dict, optional

                Extra arguments passed to the loss function call(s)



            Returns

            -------

            outputs of callable loss function(s)



            """

            # TODO: WIP Handle multiple losses. Should support list/tuple of losses. Check non base PyTorch

            with self._runner.loss_context:

                if isinstance(self._loss, (list, tuple)):

                    loss = type(self._loss)(val(*args, **kwargs) for val in self._loss)

                    sync_loss = [self.detach_and_sync_loss(val) for val in loss]

                    self._last_step_loss = type(self._loss)(

                        val for idx, val in enumerate(sync_loss)

                    )

                    self._agg_loss = type(self._loss)(

                        self._agg_loss[idx] + val for idx, val in enumerate(sync_loss)

                    )

                    self._handle_ema_loss(loss=sync_loss)

                    if self.grad_accum > 1 and self.model_access.training:

                        loss = type(loss)(val / self.grad_accum for val in loss)

                else:

                    loss = self._loss(*args, **kwargs)

                    sync_loss = self.detach_and_sync_loss(loss)

                    self._last_step_loss = sync_loss

                    self._agg_loss += sync_loss

                    self._handle_ema_loss(loss=sync_loss)

                    # Handle grad accumulation by dividing by the accumulation steps

                    if self.grad_accum > 1 and self.model_access.training:

                        loss = loss / self.grad_accum

                return loss



        def _handle_ema_loss(self, loss: Union[float, List[float], Tuple[float]]):

            """Handles calculating the ema loss



            Parameters

            ----------

            loss: Union[float, List[float], Tuple[float]]

                current calculated loss list, tuple or float



            Returns

            -------

            None



            """

            self._rolling_loss_steps += 1

            if isinstance(loss, (list, tuple)):

                self._rolling_mean_loss = type(self._rolling_mean_loss)(

                    self._ema_loss(value=val, current_mean=self._rolling_mean_loss[idx])

                    for idx, val in enumerate(loss)

                )

            else:

                self._rolling_mean_loss = self._ema_loss(

                    value=loss, current_mean=self._rolling_mean_loss

                )



        def _ema_loss(self, value: float, current_mean: float):

            """Calculate the ema of the loss



            Parameters

            ----------

            value: float

                current loss value

            current_mean: float

                current mean value



            Returns

            -------

            current ema value: float



            """

            if self._rolling_loss_steps == 1:

                return value

            else:

                return (self._ema_weight * value) + (

                    (1.0 - self._ema_weight) * current_mean

                )



        def backward(

            self, loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

        ):

            """Wrapped backwards call



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                Computed loss tensor(s) to backpropagate



            Returns

            -------

            None



            """

            # Increment the grad counter

            self._grad_accum_counter += 1

            # Set the context based on the counter

            dist_cm = (

                nullcontext()

                if self._check_accum()

                else self._runner.grad_accum_context(self._model)

            )

            with dist_cm:

                self._runner.backward_call(

                    loss=loss, model=self.model_access, optimizer=self._optimizer

                )

            # Increment the number of total calls to backward (each backward to a loss is only considered 1)

            self._backward_steps += 1



        def step(self):

            """Wrapped step call



            Handles grad clipping internally



            Returns

            -------

            None



            """

            # Step the optimizer only if the modulo is zero

            if self._check_accum():

                if self._verbose and self.grad_accum > 0:

                    self.print(f"Gradient Accumulation Steps: {self.grad_accum}")

                # Clip if needed

                if self.grad_clip is not None:

                    self._runner.clip_grad(

                        self.grad_clip,

                        self._model if self.fully_sharded else self.model_access,

                        self._optimizer,

                        oss=self.oss,

                        horovod=self.is_horovod,

                        deepspeed=self.is_deepspeed,

                        fsdp=self.fully_sharded,

                    )

                # Handle the optimizer step

                step_cm = (

                    self._runner.step_context(self._optimizer)

                    if self.grad_clip is not None

                    else nullcontext()

                )

                with step_cm:

                    self._runner.step_call(

                        model=self.model_access, optimizer=self._optimizer

                    )

                # Reset for the accumulated step

                self._reset()

                # Increment the number of step calls to the optimizer

                self._optimizer_steps += 1

            # if deepspeed we need to step every time as it handles the grad accumulation internally

            elif self.is_deepspeed:

                # Handle the optimizer step

                step_cm = (

                    self._runner.step_context(self._optimizer)

                    if self.grad_clip is not None

                    else nullcontext()

                )

                with step_cm:

                    self._runner.step_call(

                        model=self.model_access, optimizer=self._optimizer

                    )



        def _reset(self):

            """Resets the state post optimizer step call



            Returns

            -------

            None



            """

            if self._verbose:

                self.print("Resetting all grad/variables for next optimizer step")

            # Zero the grads if not deepspeed

            if not self.is_deepspeed:

                self.zero_grads()

            # Reset counter

            self._grad_accum_counter = 0

            # Reset agg loss -- single or multiple losses

            self._agg_loss = self._set_loss_to_zero()



        def save(

            self,

            path: str,

            name: str = uuid4(),

            extension: str = "pt",

            create_directory: bool = True,

            extras: Optional[dict] = None,

        ):

            """Saves a model checkpoint using the correct backend interface



            Parameters

            ----------

            path: str

                path to directory to save the model checkpoint (prefer absolute paths over relative paths)

            name: str, default: uuid4()

                name used to save checkpoint file

            extension: str, default: 'pt'

                extension used to save PyTorch model checkpoint

            create_directory: bool, default: True

                flag to create the directory path if it doesn't exist

            extras: dict, default: None

                a dictionary of any extra things to save



            Returns

            -------

            path: str

                path to directory that the model checkpoint was saved

            tag: str

                full tag name the model checkpoint was saved as



            """

            out_path, tag = self._runner.save(

                model=self._model if self.fully_sharded else self.model_access,

                optimizer=self.optimizer,

                path=path,

                backward_step=self._backward_steps,

                grad_accum_step=self._grad_accum_counter,

                optimizer_step=self._optimizer_steps,

                name=name,

                scaler_dict=self.fp16_state_dict,

                extension=extension,

                create_directory=create_directory,

                extras=extras,

                status=self.status.status,

            )

            self.print(f"Successfully saved model checkpoint to {out_path}/{tag}")

            return out_path, tag



        def load(self, path: str, tag: str, strict: bool = True):

            """Loads a model checkpoint using the correct backend interface



            Parameters

            ----------

            path: str

                path to directory that the model checkpoint was saved (prefer absolute paths over relative paths)

            tag: str

                full tag name the model checkpoint was saved as

            strict: bool, default: True

                whether to strictly enforce that the checkpoint state dict keys match the model's state dict keys



            Returns

            -------

            extras: dict, default: None

                a dictionary of any custom fields the user passed to the save function



            """

            # TODO: How to deal with mapping between backends? e.g. FP16 model back to FP32? Or multi-gpu to CPU?

            backward_step, grad_accum_step, optimizer_step, extras = self._runner.load(

                model=self._model if self.fully_sharded else self.model_access,

                optimizer=self.optimizer,

                gpu=self.gpu,

                path=path,

                tag=tag,

                scaler_dict_fn=self._load_fp16_state_dict_fn(),

                strict=strict,

            )

            # Reset values based on what was in the load dict

            self._backward_steps = backward_step

            self._grad_accum_counter = grad_accum_step

            self._optimizer_steps = optimizer_step

            self.print(f"Successfully loaded model checkpoint from {path}/{tag}")

            # Return the extras dict

            return extras



        def print_num_model_parameters(

            self, normalize: ParamNormalize = ParamNormalize.MILLION

        ):

            """



            Parameters

            ----------

            normalize: ParamNormalize, default: ParamNormalize.MILLION

                ParamNormalize choice for pretty print normalizing



            Returns

            -------

            None



            """

            self.print(

                f"Total Trainable Model Parameters: "

                f"{(self.num_model_parameters / normalize.value):.3f} {normalize.name}"

            )



        def detach_and_sync_loss(

            self,

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],

            device=None,

        ):

            """Shorthand method to detach and sync loss



            Maps to the runner function of the same name



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                current loss(es)

            device: default: None

                device to sync across



            Returns

            -------

            loss that is synced across devices and all_reduced w/ SUM



            """

            return self._runner.detach_and_sync_loss(loss=loss, device=device)



        def zero_grads(self):

            """Zeros the optimizer grads depending on the optimizer type



            Returns

            -------

            None



            """

            zero_optimizer_grads(

                optimizer=self._optimizer, apex=self.is_apex, horovod=self.is_horovod

            )



        def reset(self):

            """Public method for resetting the underlying stoke state



            Returns

            -------

            None



            """

            self._reset()



        def reset_tracking(self):

            """Public method for resetting all underlying stoke tracked variables



            Returns

            -------

            None



            """

            # Create some tracking vars

            self._grad_accum_counter = 0

            self._optimizer_steps = 0

            self._backward_steps = 0

            self._last_step_loss = self._set_loss_to_zero()

            self._agg_loss = self._set_loss_to_zero()

            self._rolling_mean_loss = self._set_loss_to_zero()

            self._rolling_loss_steps = 0



        def dump_model_parameter_info(self):

            """Dumps all parameter information for named parameters (shape, device, dtype)



            Returns

            -------

            None



            """

            self.print("Dumping all model parameter information to stdout....")

            for name, param in self.model_access.named_parameters():

                if param.requires_grad:

                    self.print(

                        f"Name: {name}, Shape: {param.shape}, "

                        f"Device: {param.device}, dtype: {param.dtype}"

                    )



        def _load_fp16_state_dict_fn(self):

            """Returns the function to load the sacler state dict



            Returns

            -------

            mp_state_dict_fn: Callable, default: None

                callable function to load the scaler state dict



            """

            mp_state_dict_fn = None

            if self.scaler is not None:

                if self.is_apex:

                    try:

                        from apex import amp



                        mp_state_dict_fn = amp.load_state_dict

                    except ImportError as e:

                        print(

                            e,

                            ": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",

                        )

                else:

                    mp_state_dict_fn = self.scaler.load_state_dict

            return mp_state_dict_fn



        def barrier(self):

            """Calls the underlying distributed barrier if available"""

            self._runner.barrier()



        @property

        def step_loss(self):

            """Gets the last step loss synced across device(s) (unscaled)"""

            return self._last_step_loss



        @property

        def model_access(self):

            """Interface for model access due to the different types between the DP, DDP, and SDDP implementations"""

            if isinstance(self._model, (DDP, DP, SDDP, FSDP)):

                return self._model.module

            else:

                return self._model



        @property

        def loss_access(self):

            """Gets loss tensor(s)"""

            return self._loss



        @property

        def optimizer(self):

            """Gets the optimizer"""

            return self._optimizer



        @property

        def scaler(self):

            """Gets the current scaler object"""

            return self._runner.scaler



        @property

        def fp16_state_dict(self):

            """Gets the fp16 state dict from various methods"""

            mp_state_dict = None

            if self.scaler is not None:

                if self.is_apex:

                    try:

                        from apex import amp



                        mp_state_dict = amp.state_dict()

                    except ImportError as e:

                        print(

                            e,

                            ": Stoke -- apex cannot be imported -- please install (https://github.com/NVIDIA/apex)",

                        )

                elif self.is_amp:

                    mp_state_dict = self.scaler.state_dict()

            return mp_state_dict



        @property

        def status(self):

            """Gets the StokeStatus object"""

            return self._status



        @property

        def batch_size(self):

            """Shortcut to batch size"""

            return self._status.batch_size



        @property

        def effective_batch_size(self):

            """Shortcut to effective batch size"""

            return self._status.effective_batch_size



        @property

        def grad_clip(self):

            """Shortcut to get grad clip"""

            return self._status.grad_clip



        @property

        def grad_accum(self):

            """Shortcut to get grad accumulation"""

            return self._status.grad_accum



        @property

        def gpu(self):

            """Shortcut to get GPU status"""

            return self._status.gpu



        @property

        def cuda(self):

            """Shortcut to get cuda status"""

            return self._status.cuda



        @property

        def nccl(self):

            """Shortcut to get nccl status"""

            return self._status.nccl



        @property

        def fp16(self):

            """Shortcut to get FP16 status"""

            return self._status.fp16



        @property

        def is_apex(self):

            """Returns if APEX is activated"""

            return self._status.is_fp16_apex



        @property

        def is_amp(self):

            """Returns if AMP is activated"""

            return self._status.is_fp16_amp



        @property

        def distributed(self):

            """Shortcut to distributed status"""

            return self._status.distributed



        @property

        def is_ddp(self):

            """Returns if DDP is activated"""

            return self._status.is_distributed_ddp



        @property

        def is_horovod(self):

            """Returns if Horovod is activated"""

            return self._status.is_distributed_horovod



        @property

        def is_deepspeed(self):

            """Returns if Deepspeed is acticated"""

            return self._status.is_distributed_deepspeed



        @property

        def oss(self):

            """Returns if Fairscale optimizer state sharding status"""

            return self._status.oss



        @property

        def sharded(self):

            """Returns if Fairscale sharded DDP status"""

            return self._status.sharded



        @property

        def fully_sharded(self):

            """Returns if Fairscale fully sharded DDP status"""

            return self._status.fully_sharded



        @property

        def world_size(self):

            """Shortcut to get world size"""

            return self._runner.world_size



        @property

        def rank(self):

            """Shortcut to get rank"""

            return self._runner.rank



        @property

        def amp_config(self):

            """Returns amp config or None based on amp state"""

            return self._status.amp_config if self.is_amp else None



        @property

        def apex_config(self):

            """Returns apex config or None based on apex state"""

            return self._status.apex_config if self.is_apex else None



        @property

        def ddp_config(self):

            """Returns ddp config or None based on ddp state"""

            return self._status.ddp_config if self.is_ddp else None



        @property

        def deepspeed_config(self):

            """Returns deepspeed config or None based on deepspeed state"""

            return self._status.deepspeed_config if self.is_deepspeed else None



        @property

        def oss_config(self):

            """Returns oss config or None based on ossstate"""

            return self._status.oss_config if self.oss else None



        @property

        def sddp_config(self):

            """Returns sddp config or None based on sddp state"""

            return self._status.sddp_config if self.sharded else None



        @property

        def fsdp_config(self):

            """Returns fsdp config or None based on fsdp state"""

            return self._status.fsdp_config if self.fully_sharded else None



        @property

        def horovod_config(self):

            """Returns horovod config or None based on horovod state"""

            return self._status.horovod_config if self.is_horovod else None



        @property

        def num_model_parameters(self):

            """Returns number of parameters that require gradients"""

            return sum(p.numel() for p in self.model_access.parameters() if p.requires_grad)



        @property

        def ema_loss(self):

            """Returns the current rolling mean loss"""

            return self._rolling_mean_loss

Instance variables

amp_config

Returns amp config or None based on amp state

apex_config

Returns apex config or None based on apex state

batch_size

Shortcut to batch size

cuda

Shortcut to get cuda status

ddp_config

Returns ddp config or None based on ddp state

deepspeed_config

Returns deepspeed config or None based on deepspeed state

distributed

Shortcut to distributed status

effective_batch_size

Shortcut to effective batch size

ema_loss

Returns the current rolling mean loss

fp16

Shortcut to get FP16 status

fp16_state_dict

Gets the fp16 state dict from various methods

fsdp_config

Returns fsdp config or None based on fsdp state

fully_sharded

Returns the Fairscale fully sharded DDP (FSDP) status

gpu

Shortcut to get GPU status

grad_accum

Shortcut to get grad accumulation

grad_clip

Shortcut to get grad clip

horovod_config

Returns horovod config or None based on horovod state

is_amp

Returns if AMP is activated

is_apex

Returns if APEX is activated

is_ddp

Returns if DDP is activated

is_deepspeed

Returns if Deepspeed is activated

is_horovod

Returns if Horovod is activated

loss_access

Gets loss tensor(s)

model_access

Interface for model access due to the different types between the DP, DDP, SDDP, and FSDP implementations

nccl

Shortcut to get nccl status

num_model_parameters

Returns number of parameters that require gradients

optimizer

Gets the optimizer

oss

Returns the Fairscale optimizer state sharding (OSS) status

oss_config

Returns oss config or None based on oss state

rank

Shortcut to get rank

scaler

Gets the current scaler object

sddp_config

Returns sddp config or None based on sddp state

sharded

Returns the Fairscale sharded DDP (SDDP) status

status

Gets the StokeStatus object

step_loss

Gets the last step loss synced across device(s) (unscaled)

world_size

Shortcut to get world size

Methods

DataLoader

def DataLoader(
    self,
    dataset: torch.utils.data.dataset.Dataset[+T_co],
    shuffle: bool = False,
    sampler: Union[torch.utils.data.sampler.Sampler[int], NoneType] = None,
    batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence[int]], NoneType] = None,
    num_workers: int = 0,
    collate_fn: Callable[[List[~T]], Any] = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float = 0,
    worker_init_fn: Callable[[int], NoneType] = None,
    multiprocessing_context=None,
    generator=None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = False
)

Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.

The shim is necessary for two reasons: to inject some Horovod runtime configuration (ensuring the forkserver start method is used when available) and to automatically handle device placement, since the gpu/fp16 flags cannot be determined until the StokeStatus object exists, which is post-init. This could be decoupled from this class, but it would require the user to forward device or fp16 configs, which breaks the paradigm that the flags only need to be set once and never handled manually.

Parameters:

Name Type Description Default
dataset Dataset dataset from which to load the data. None
shuffle bool, default: False set to True to have the data reshuffled at every epoch. None
sampler Sampler or Iterable, default: None defines the strategy to draw samples from the dataset. Can be any Iterable with __len__
implemented. If specified, :attr:shuffle must not be specified. None
batch_sampler Sampler or Iterable, default: None: like :attr:sampler, but returns a batch of indices at a time. Mutually exclusive with
:attr:batch_size, :attr:shuffle, :attr:sampler, and :attr:drop_last. None
num_workers int, default: 0 how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. None
collate_fn callable, optional: merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset. None
pin_memory bool, default: False: If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:collate_fn returns a batch that is a custom type,
see the example below. None
drop_last bool, default: False set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If False and the size of dataset is not divisible by the batch size, then the last batch
will be smaller. None
timeout numeric, default: 0 if positive, the timeout value for collecting a batch from workers. Should always be non-negative. None
worker_init_fn callable, default: None If not None, this will be called on each worker subprocess with the worker id
(an int in [0, num_workers - 1]) as input, after seeding and before data loading. None
prefetch_factor int, default: 2 Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers
samples prefetched across all workers. None
persistent_workers bool, default: False If True, the data loader will not shutdown the worker processes after a dataset has been
consumed once. This allows to maintain the workers Dataset instances alive. None

Returns:

Type Description
StokeDataLoader wrapped torch.utils.data.DataLoader object
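A minimal usage sketch (hypothetical names: `stoke_obj` is an already-constructed Stoke instance and `train_dataset` is a standard map-style `Dataset`); note that `batch_size` is not passed here because it is forwarded automatically from the underlying StokeStatus:

    # Assumed: stoke_obj (Stoke instance) and train_dataset (torch Dataset) already exist
    train_loader = stoke_obj.DataLoader(
        dataset=train_dataset,
        num_workers=4,       # worker subprocesses for loading
        pin_memory=True,     # pin host memory for faster host-to-device copies
        drop_last=True,      # drop the final partial batch
    )
    for batch in train_loader:
        ...  # batches are moved to the GPU automatically when the gpu flag is set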

??? example "View Source" def DataLoader(

            self,

            dataset: Dataset[T_co],

            shuffle: bool = False,

            sampler: Optional[Sampler[int]] = None,

            batch_sampler: Optional[Sampler[Sequence[int]]] = None,

            num_workers: int = 0,

            collate_fn: _collate_fn_t = None,

            pin_memory: bool = False,

            drop_last: bool = False,

            timeout: float = 0,

            worker_init_fn: _worker_init_fn_t = None,

            multiprocessing_context=None,

            generator=None,

            *,

            prefetch_factor: int = 2,

            persistent_workers: bool = False,

        ):

            """Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs.



            Shim is necessary for two reasons... to inject some horovod runtime configs (make sure forkserver is called)

            and to automatically handle device placement since the gpu/fp16 flags can't be determined until the StokeStatus

            object is available which is post init. This could be disconnected from this class but it would require the

            user to forward on device or fp16 configs which breaks the paradigm that the flags only need to be set and

            never handled



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            shuffle: bool, default: False

                set to ``True`` to have the data reshuffled at every epoch.

            sampler: Sampler or Iterable, default: None

                defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``

                implemented. If specified, :attr:`shuffle` must not be specified.

            batch_sampler: Sampler or Iterable, default: None:

                like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with

                :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.

            num_workers: int, default: 0

                how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.

            collate_fn: callable, optional:

                merges a list of samples to form a mini-batch of Tensor(s).  Used when using batched loading from a

                map-style dataset.

            pin_memory: bool, default: False:

                If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your

                data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,

                see the example below.

            drop_last: bool, default: False

                set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.

                If ``False`` and the size of dataset is not divisible by the batch size, then the last batch

                will be smaller.

            timeout: numeric, default: 0

                if positive, the timeout value for collecting a batch from workers. Should always be non-negative.

            worker_init_fn: callable, default: None

                If not ``None``, this will be called on each worker subprocess with the worker id

                (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.

            prefetch_factor: int, default: 2

                Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers

                samples prefetched across all workers.

            persistent_workers: bool, default: False

                If ``True``, the data loader will not shutdown the worker processes after a dataset has been

                consumed once. This allows to maintain the workers `Dataset` instances alive.



            Returns

            -------

            StokeDataLoader

                wrapped torch.utils.data.DataLoader object



            """

            # Check if forkserver is available for horovod and use

            if (

                num_workers > 0

                and hasattr(torch.multiprocessing, "_supports_context")

                and torch.multiprocessing._supports_context

                and "forkserver" in torch.multiprocessing.get_all_start_methods()

                and self.is_horovod

            ):

                multiprocessing_context = "forkserver"



            if self._verbose and self.gpu:

                print(f"Automatically handling moving model input data to GPU(s)...")

            # Forward the already known options from the Stoke status

            return StokeDataLoader(

                gpu=self.gpu,

                fp16=self.fp16,

                batch_size=self.batch_size,

                dataset=dataset,

                shuffle=shuffle,

                sampler=sampler,

                batch_sampler=batch_sampler,

                num_workers=num_workers,

                collate_fn=collate_fn,

                pin_memory=pin_memory,

                drop_last=drop_last,

                timeout=timeout,

                worker_init_fn=worker_init_fn,

                multiprocessing_context=multiprocessing_context,

                generator=generator,

                prefetch_factor=prefetch_factor,

                persistent_workers=persistent_workers,

            )

backward

def backward(
    self,
    loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
)

Wrapped backwards call

Parameters:

Name Type Description Default
loss Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] computed loss tensor(s) to backpropagate None

Returns:

Type Description
None None
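Taken together with the wrapped `model`, `loss`, and `step` calls, a single training iteration typically looks like the sketch below (hypothetical names; `stoke_obj` is an assumed, already-constructed Stoke instance and `x`, `y` come from a Stoke `DataLoader`):

    # Assumed: stoke_obj, x, y already exist
    outputs = stoke_obj.model(x)             # wrapped forward call
    train_loss = stoke_obj.loss(outputs, y)  # wrapped loss call (syncs and tracks the loss)
    stoke_obj.backward(train_loss)           # wrapped backward (manages grad-accumulation contexts)
    stoke_obj.step()                         # optimizer step, grad clipping, and internal reset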

??? example "View Source" def backward(

            self, loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

        ):

            """Wrapped backwards call



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                Computed loss tensor(s) to backpropagate



            Returns

            -------

            None



            """

            # Increment the grad counter

            self._grad_accum_counter += 1

            # Set the context based on the counter

            dist_cm = (

                nullcontext()

                if self._check_accum()

                else self._runner.grad_accum_context(self._model)

            )

            with dist_cm:

                self._runner.backward_call(

                    loss=loss, model=self.model_access, optimizer=self._optimizer

                )

            # Increment the number of total calls to backward (each backward to a loss is only considered 1)

            self._backward_steps += 1

barrier

def barrier(
    self
)

Calls the underlying distributed barrier if available

??? example "View Source" def barrier(self):

            """Calls the underlying distributed barrier if available"""

            self._runner.barrier()

detach_and_sync_loss

def detach_and_sync_loss(
    self,
    loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
    device=None
)

Shorthand method to detach and sync loss

Maps to the runner function of the same name

Parameters:

Name Type Description Default
loss Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] current loss(es) None
device default: None device to sync across None

Returns:

Type Description
loss that is synced across devices and all_reduced w/ SUM None
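For example, to log a loss value that has been detached from the graph and reduced across devices (a sketch; `stoke_obj` and `raw_loss` are assumed to exist):

    # raw_loss is a torch.Tensor computed on this device
    synced = stoke_obj.detach_and_sync_loss(raw_loss)
    stoke_obj.print(f"Synced loss: {synced:.4f}")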

??? example "View Source" def detach_and_sync_loss(

            self,

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],

            device=None,

        ):

            """Shorthand method to detach and sync loss



            Maps to the runner function of the same name



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                current loss(es)

            device: default: None

                device to sync across



            Returns

            -------

            loss that is synced across devices and all_reduced w/ SUM



            """

            return self._runner.detach_and_sync_loss(loss=loss, device=device)

dump_model_parameter_info

def dump_model_parameter_info(
    self
)

Dumps all parameter information for named parameters (shape, device, dtype)

Returns:

Type Description
None None

??? example "View Source" def dump_model_parameter_info(self):

            """Dumps all parameter information for named parameters (shape, device, dtype)



            Returns

            -------

            None



            """

            self.print("Dumping all model parameter information to stdout....")

            for name, param in self.model_access.named_parameters():

                if param.requires_grad:

                    self.print(

                        f"Name: {name}, Shape: {param.shape}, "

                        f"Device: {param.device}, dtype: {param.dtype}"

                    )

load

def load(
    self,
    path: str,
    tag: str,
    strict: bool = True
)

Loads a model checkpoint using the correct backend interface

Parameters:

Name Type Description Default
path str path to directory that the model checkpoint was saved (prefer absolute paths over relative paths) None
tag str full tag name the model checkpoint was saved as None
strict bool, default: True whether to strictly enforce that the checkpoint state dict keys match the model's state dict keys None

Returns:

Type Description
dict, default: None a dictionary of any custom fields the user passed to the save function
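A sketch of a save/load round trip using the values returned by `save` (hypothetical paths and names; `stoke_obj` is an assumed, already-constructed Stoke instance):

    # Save a checkpoint and reload it later with the returned path/tag pair
    out_path, tag = stoke_obj.save(
        path="/tmp/checkpoints", name="epoch_3", extras={"epoch": 3}
    )
    ...
    extras = stoke_obj.load(path=out_path, tag=tag)
    start_epoch = extras["epoch"] + 1  # user-defined extras are returned by load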

??? example "View Source" def load(self, path: str, tag: str, strict: bool = True):

            """Loads a model checkpoint using the correct backend interface



            Parameters

            ----------

            path: str

                path to directory that the model checkpoint was saved (prefer absolute paths over relative paths)

            tag: str

                full tag name the model checkpoint was saved as

            strict: bool, default: True

                whether to strictly enforce that the checkpoint state dict keys match the model's state dict keys



            Returns

            -------

            extras: dict, default: None

                a dictionary of any custom fields the user passed to the save function



            """

            # TODO: How to deal with mapping between backends? e.g. FP16 model back to FP32? Or multi-gpu to CPU?

            backward_step, grad_accum_step, optimizer_step, extras = self._runner.load(

                model=self._model if self.fully_sharded else self.model_access,

                optimizer=self.optimizer,

                gpu=self.gpu,

                path=path,

                tag=tag,

                scaler_dict_fn=self._load_fp16_state_dict_fn(),

                strict=strict,

            )

            # Reset values based on what was in the load dict

            self._backward_steps = backward_step

            self._grad_accum_counter = grad_accum_step

            self._optimizer_steps = optimizer_step

            self.print(f"Successfully loaded model checkpoint from {path}/{tag}")

            # Return the extras dict

            return extras

loss

def loss(
    self,
    *args,
    **kwargs
)

Wrapped callable loss function call

Handles the internal logic of aggregating single or multiple losses

Parameters:

Name Type Description Default
*args list or tuple Additional arguments should be passed as keyword arguments None
**kwargs dict Extra arguments passed to the loss function call(s) None

Returns:

Type Description
outputs of callable loss function(s) None
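When gradient accumulation is active and the model is in training mode, the value returned here is already divided by `grad_accum`, while the unscaled, device-synced value for the step remains available on the `step_loss` property. A sketch (assumed `stoke_obj` configured with `grad_accum > 1`):

    # The tensor handed to backward() is pre-scaled by 1 / grad_accum
    scaled_loss = stoke_obj.loss(outputs, y)
    stoke_obj.backward(scaled_loss)
    # The unscaled, synced loss for this step is tracked separately
    last_loss = stoke_obj.step_loss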

??? example "View Source" def loss(self, args, *kwargs):

            """Wrapped callable loss function call



            Handles internal logic of aggregating up the losses for single and multiple losses



            Parameters

            ----------

            *args: list or tuple

                Additional arguments should be passed as keyword arguments

            **kwargs: dict, optional

                Extra arguments passed to the loss function call(s)



            Returns

            -------

            outputs of callable loss function(s)



            """

            # TODO: WIP Handle multiple losses. Should support list/tuple of losses. Check non base PyTorch

            with self._runner.loss_context:

                if isinstance(self._loss, (list, tuple)):

                    loss = type(self._loss)(val(*args, **kwargs) for val in self._loss)

                    sync_loss = [self.detach_and_sync_loss(val) for val in loss]

                    self._last_step_loss = type(self._loss)(

                        val for idx, val in enumerate(sync_loss)

                    )

                    self._agg_loss = type(self._loss)(

                        self._agg_loss[idx] + val for idx, val in enumerate(sync_loss)

                    )

                    self._handle_ema_loss(loss=sync_loss)

                    if self.grad_accum > 1 and self.model_access.training:

                        loss = type(loss)(val / self.grad_accum for val in loss)

                else:

                    loss = self._loss(*args, **kwargs)

                    sync_loss = self.detach_and_sync_loss(loss)

                    self._last_step_loss = sync_loss

                    self._agg_loss += sync_loss

                    self._handle_ema_loss(loss=sync_loss)

                    # Handle grad accumulation by dividing by the accumulation steps

                    if self.grad_accum > 1 and self.model_access.training:

                        loss = loss / self.grad_accum

                return loss

model

def model(
    self,
    *args,
    **kwargs
)

Wrapped model forward call

Parameters:

Name Type Description Default
*args list or tuple Additional arguments should be passed as keyword arguments None
**kwargs dict Extra arguments passed to the model forward call None

Returns:

Type Description
model forward output None

??? example "View Source" def model(self, args, *kwargs):

            """Wrapped model forward call



            Parameters

            ----------

            *args: list or tuple

                Additional arguments should be passed as keyword arguments

            **kwargs: dict, optional

                Extra arguments passed to the model forward call



            Returns

            -------

            model forward output



            """

            with self._runner.model_context:

                return self._model(*args, **kwargs)

                # return self.model_access(*args, **kwargs)

print

def print(
    self,
    msg: Union[str, List[str]],
    single_line: bool = False
)

Wraps the runner's print device and forces printing on the _info_rank attribute(s)

Parameters:

Name Type Description Default
msg str message to print None
single_line bool, default: False if iterable print all on one line space and comma separated None

Returns:

Type Description
None None

??? example "View Source" def print(self, msg: Union[str, List[str]], single_line: bool = False):

            """Wraps the runners print device and forces print on the _info_rank attribute(s)



            Parameters

            ----------

            msg: str

                message to print

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            self._runner.print_device(

                msg=msg, rank=self._info_rank, single_line=single_line

            )

print_ema_loss

def print_ema_loss(
    self,
    prepend_msg: str = 'Current EMA Loss',
    single_line: bool = False
)

Prints the current ema loss synced across all devices

Handles single or multiple losses. Prints only on devices specified by self._info_rank

Parameters:

Name Type Description Default
prepend_msg str, default: "Current EMA Loss" message prepend to print None
single_line bool, default: False if iterable print all on one line space and comma separated None

Returns:

Type Description
None None
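For instance, to report the rolling mean loss periodically inside a training loop (a sketch; `stoke_obj` and `step` are assumed):

    if step % 100 == 0:
        stoke_obj.print_ema_loss(prepend_msg="EMA Loss")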

??? example "View Source" def print_ema_loss(

            self, prepend_msg: str = "Current EMA Loss", single_line: bool = False

        ):

            """Prints the current ema loss synced across all devices



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            prepend_msg: str, default: "Current EMA Loss"

                message prepend to print

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            if isinstance(self._rolling_mean_loss, (list, tuple)):

                print_vals = [

                    f"{prepend_msg} {idx}: {val:.3f}"

                    for idx, val in enumerate(self._rolling_mean_loss)

                ]

                self.print(print_vals, single_line=single_line)

            else:

                self.print(f"{prepend_msg}: {self._rolling_mean_loss:.3f}")

print_mean_accumulated_synced_loss

def print_mean_accumulated_synced_loss(
    self,
    prepend_msg: str = 'Mean Accumulated & Synced Loss',
    pre_backwards: bool = True,
    single_line: bool = False
)

Prints the mean accumulated and device synced loss only after the grad accumulation step

Handles single or multiple losses. Prints only on devices specified by self._info_rank

Parameters:

Name Type Description Default
prepend_msg str, default: "Mean Accumulated & Synced Loss" message prepend to print None
pre_backwards bool, default: True if being called pre backward step None
single_line bool, default: False if iterable print all on one line space and comma separated None

Returns:

Type Description
None None
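A sketch of placing the call just before `backward` in the training loop, which matches the default `pre_backwards=True` (hypothetical names; `stoke_obj`, `outputs`, and `y` are assumed):

    # Called before backward in this loop, so the default pre_backwards=True applies
    train_loss = stoke_obj.loss(outputs, y)
    stoke_obj.print_mean_accumulated_synced_loss()
    stoke_obj.backward(train_loss)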

??? example "View Source" def print_mean_accumulated_synced_loss(

            self,

            prepend_msg: str = "Mean Accumulated & Synced Loss",

            pre_backwards: bool = True,

            single_line: bool = False,

        ):

            """Prints the mean accumulated and device synced loss only after the grad accumulation step



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            prepend_msg: str, default: "Mean Accumulated & Synced Loss"

                message prepend to print

            pre_backwards: bool, default: True

                if being called pre backward step

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            check_fn = self._check_pre_accum if pre_backwards else self._check_accum

            if check_fn():

                if isinstance(self._agg_loss, (list, tuple)):

                    print_vals = self._scale_agg_loss()

                    self.print(print_vals, single_line=single_line)

                else:

                    self.print(f"{prepend_msg}: {self._scale_agg_loss():.3f}")

print_num_model_parameters

def print_num_model_parameters(
    self,
    normalize: stoke.utils.ParamNormalize = <ParamNormalize.MILLION: 1000000.0>
)

Parameters:

Name Type Description Default
normalize ParamNormalize, default: ParamNormalize.MILLION ParamNormalize choice for pretty print normalizing None

Returns:

Type Description
None None
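A minimal sketch (assuming `stoke_obj` is an already-constructed Stoke instance; `ParamNormalize` is exported at the package level):

    from stoke import ParamNormalize

    # Print the trainable parameter count normalized to millions
    stoke_obj.print_num_model_parameters(normalize=ParamNormalize.MILLION)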

??? example "View Source" def print_num_model_parameters(

            self, normalize: ParamNormalize = ParamNormalize.MILLION

        ):

            """



            Parameters

            ----------

            normalize: ParamNormalize, default: ParamNormalize.MILLION

                ParamNormalize choice for pretty print normalizing



            Returns

            -------

            None



            """

            self.print(

                f"Total Trainable Model Parameters: "

                f"{(self.num_model_parameters / normalize.value):.3f} {normalize.name}"

            )

print_on_devices

def print_on_devices(
    self,
    msg: Union[str, List[str]],
    rank: Union[int, List[int], NoneType] = 0
)

Wraps runner print interface for shorter semantics

Parameters:

Name Type Description Default
msg str message to print None
rank Union[int, List[int]], default: 0 which ranks to print on None

Returns:

Type Description
None None
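A sketch of printing a message on a specific set of ranks (hypothetical messages; `stoke_obj` is assumed):

    # Print only on rank 0 (the default)
    stoke_obj.print_on_devices("Starting training run")
    # Print on ranks 0 and 1
    stoke_obj.print_on_devices("Per-node status OK", rank=[0, 1])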

??? example "View Source" def print_on_devices(

            self, msg: Union[str, List[str]], rank: Optional[Union[int, List[int]]] = 0

        ):

            """Wraps runner print interface for shorter semantics



            Parameters

            ----------

            msg: str

                message to print

            rank: Union[int, List[int]], default: 0

                which ranks to print on



            Returns

            -------

            None



            """

            self._runner.print_device(msg=msg, rank=rank)

print_synced_loss

def print_synced_loss(
    self,
    loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
    prepend_msg: str = 'Step Synced Loss',
    device=None,
    single_line: bool = False
)

Prints a device synced loss at a single step

Handles single or multiple losses. Prints only on devices specified by self._info_rank

Parameters:

Name Type Description Default
loss Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] current loss(es) on the device None
prepend_msg str, default: "Step Synced Loss" message prepend to print None
device default: None specify the device to place the synced loss on (defaults to same device) same
single_line bool, default: False if iterable print all on one line space and comma separated None

Returns:

Type Description
None None
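A sketch of printing the device-synced loss right after it is computed (hypothetical names; `stoke_obj`, `outputs`, and `y` are assumed). The printed value is re-scaled by `grad_accum` internally so it reflects the full, unscaled step loss:

    train_loss = stoke_obj.loss(outputs, y)
    stoke_obj.print_synced_loss(train_loss)  # prints "Step Synced Loss: ..." on the info rank(s)
    stoke_obj.backward(train_loss)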

??? example "View Source" def print_synced_loss(

            self,

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],

            prepend_msg: str = "Step Synced Loss",

            device=None,

            single_line: bool = False,

        ):

            """Prints a device synced loss at a single step



            Handles single or multiple losses. Prints only on devices specified by self._info_rank



            Parameters

            ----------

            loss: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]

                current loss(es) on the device

            prepend_msg: str, default: "Step Synced Loss"

                message prepend to print

            device: default: None

                specify the device to place the synced loss on (defaults to same device)

            single_line: bool, default: False

                if iterable print all on one line space and comma separated



            Returns

            -------

            None



            """

            printable_loss = self.detach_and_sync_loss(loss, device)

            if isinstance(printable_loss, (list, tuple)):

                print_vals = [

                    f"{prepend_msg} {idx}: {val * self.grad_accum:.3f}"

                    for idx, val in enumerate(printable_loss)

                ]

                self.print(print_vals, single_line=single_line)

            else:

                self.print(msg=f"{prepend_msg}: {printable_loss * self.grad_accum:.3f}")

reset

def reset(
    self
)

Public method for resetting the underlying stoke state

Returns:

Type Description
None None

??? example "View Source" def reset(self):

            """Public method for resetting the underlying stoke state



            Returns

            -------

            None



            """

            self._reset()

reset_ema

def reset_ema(
    self
)

Used to reset the current state of the rolling mean loss

Returns:

Type Description
None None

??? example "View Source" def reset_ema(self):

            """Used to reset the current state of the rolling mean loss



            Returns

            -------

            None



            """

            self._rolling_mean_loss = self._set_loss_to_zero()

            self._rolling_loss_steps = 0

reset_tracking

def reset_tracking(
    self
)

Public method for resetting all underlying stoke tracked variables

Returns:

Type Description
None None

??? example "View Source" def reset_tracking(self):

            """Public method for resetting all underlying stoke tracked variables



            Returns

            -------

            None



            """

            # Create some tracking vars

            self._grad_accum_counter = 0

            self._optimizer_steps = 0

            self._backward_steps = 0

            self._last_step_loss = self._set_loss_to_zero()

            self._agg_loss = self._set_loss_to_zero()

            self._rolling_mean_loss = self._set_loss_to_zero()

            self._rolling_loss_steps = 0

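A short sketch of how the reset helpers above might be called, for example at an epoch boundary (hedged: `stoke_obj` is an assumed Stoke instance):

    # Typically called at an epoch or evaluation boundary
    stoke_obj.reset_ema()       # clear only the rolling mean loss
    stoke_obj.reset_tracking()  # clear all tracked counters and loss state
    stoke_obj.reset()           # reset the underlying stoke state (per the docstring above)
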
save

def save(
    self,
    path: str,
    name: str = uuid4(),
    extension: str = 'pt',
    create_directory: bool = True,
    extras: Union[dict, NoneType] = None
)

Saves a model checkpoint using the correct backend interface

Parameters:

Name Type Description Default
path str path to directory to save the model checkpoint (prefer absolute paths over relative paths) None
name str, default: uuid4() name used to save checkpoint file None
extension str, default: 'pt' extension used to save the PyTorch model checkpoint None
create_directory bool, default: True flag to create the directory path if it doesn't exist None
extras dict, default: None a dictionary of any extra things to save None

Returns:

Type Description
str path to the directory the model checkpoint was saved to
str full tag name the model checkpoint was saved as

??? example "View Source" def save(

            self,

            path: str,

            name: str = uuid4(),

            extension: str = "pt",

            create_directory: bool = True,

            extras: Optional[dict] = None,

        ):

            """Saves a model checkpoint using the correct backend interface



            Parameters

            ----------

            path: str

                path to directory to save the model checkpoint (prefer absolute paths over relative paths)

            name: str, default: uuid4()

                name used to save checkpoint file

            extension: str, default: '.pt'

                extension used to save PyTorch model checkpoint

            create_directory: bool, default: True

                flag to create the directory path if it doesn't exist

            extras: dict, default: None

                a dictionary of any extra things to save



            Returns

            -------

            path: str

                path to directory that the model checkpoint was saved

            tag: str

                full tag name the model checkpoint was saved as



            """

            out_path, tag = self._runner.save(

                model=self._model if self.fully_sharded else self.model_access,

                optimizer=self.optimizer,

                path=path,

                backward_step=self._backward_steps,

                grad_accum_step=self._grad_accum_counter,

                optimizer_step=self._optimizer_steps,

                name=name,

                scaler_dict=self.fp16_state_dict,

                extension=extension,

                create_directory=create_directory,

                extras=extras,

                status=self.status.status,

            )

            self.print(f"Successfully saved model checkpoint to {out_path}/{tag}")

            return out_path, tag

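A hedged usage sketch (the checkpoint directory, name, and extras below are illustrative only):

    # Returns both the output directory and the full checkpoint tag
    out_path, tag = stoke_obj.save(
        path="/tmp/checkpoints",   # illustrative path
        name="epoch_10",           # illustrative name; defaults to a fresh uuid4()
        extras={"epoch": 10},      # any extra objects to bundle into the checkpoint
    )
    stoke_obj.print_on_devices(f"wrote checkpoint {out_path}/{tag}")
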
step

def step(
    self
)

Wrapped step call

Handles grad clipping internally

Returns:

Type Description
None None

??? example "View Source" def step(self):

            """Wrapped step call



            Handles grad clipping internally



            Returns

            -------

            None



            """

            # Step the optimizer only if the modulo is zero

            if self._check_accum():

                if self._verbose and self.grad_accum > 0:

                    self.print(f"Gradient Accumulation Steps: {self.grad_accum}")

                # Clip if needed

                if self.grad_clip is not None:

                    self._runner.clip_grad(

                        self.grad_clip,

                        self._model if self.fully_sharded else self.model_access,

                        self._optimizer,

                        oss=self.oss,

                        horovod=self.is_horovod,

                        deepspeed=self.is_deepspeed,

                        fsdp=self.fully_sharded,

                    )

                # Handle the optimizer step

                step_cm = (

                    self._runner.step_context(self._optimizer)

                    if self.grad_clip is not None

                    else nullcontext()

                )

                with step_cm:

                    self._runner.step_call(

                        model=self.model_access, optimizer=self._optimizer

                    )

                # Reset for the accumulated step

                self._reset()

                # Increment the number of step calls to the optimizer

                self._optimizer_steps += 1

            # if deepspeed we need to step everytime as it handles the grad accumulation internally

            elif self.is_deepspeed:

                # Handle the optimizer step

                step_cm = (

                    self._runner.step_context(self._optimizer)

                    if self.grad_clip is not None

                    else nullcontext()

                )

                with step_cm:

                    self._runner.step_call(

                        model=self.model_access, optimizer=self._optimizer

                    )

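A hedged sketch of a single optimization step (it assumes the backward pass went through the Stoke `backward` wrapper):

    # Sketch only: `loss` comes from the Stoke loss wrapper earlier in the loop
    stoke_obj.backward(loss)
    # step() is safe to call every iteration: gradients are clipped (if configured)
    # and the optimizer actually steps only once per gradient-accumulation window
    stoke_obj.step()
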
zero_grads

def zero_grads(
    self
)

Zeros the optimizer grads depending on the optimizer type

Returns:

Type Description
None None

??? example "View Source" def zero_grads(self):

            """Zeros the optimizer grads depending on the optimizer type



            Returns

            -------

            None



            """

            zero_optimizer_grads(

                optimizer=self._optimizer, apex=self.is_apex, horovod=self.is_horovod

            )

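`step` and `zero_grads` usually appear together at the end of each training iteration; a minimal hedged sketch of such a loop (the data loader and the Stoke model/loss wrappers are assumed):

    for x, y in data_loader:
        outputs = stoke_obj.model(x)
        loss = stoke_obj.loss(outputs, y)
        stoke_obj.backward(loss)
        stoke_obj.step()
        # Dispatches to the correct gradient-zeroing path for the configured
        # backend (e.g. native PyTorch, Apex, or Horovod)
        stoke_obj.zero_grads()
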
StokeOptimizer

class StokeOptimizer(
    /,
    *args,
    **kwargs
)

Attributes

Name Type Description Default
optimizer Type[torch.optim.Optimizer] un-instantiated torch.optim.Optimizer class None
optimizer_kwargs Dict any keyword args to be unrolled into the optimizer at instantiation time None

??? example "View Source" class StokeOptimizer(TypedDict):

        """Stoke optimizer wrapper class



        Given all the different backends and extensions the optimizer might need to be instantiated in a different way

        thus this typed dict holds the configuration without instantiation



        Attributes

        ----------

        optimizer: Type[torch.optim.Optimizer]

            un-instantiated torch.optim.Optimizer class

        optimizer_kwargs: Dict

            any keyword args to be unrolled into the optimizer at instantiation time



        """



        optimizer: Type[torch.optim.Optimizer]

        optimizer_kwargs: Dict

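A hedged construction sketch: the optimizer class is passed un-instantiated along with its keyword arguments, and the resulting dict is handed to the `Stoke` constructor (other constructor arguments are omitted here):

    import torch

    from stoke import Stoke, StokeOptimizer

    # The optimizer is NOT instantiated here; Stoke instantiates it internally
    opt = StokeOptimizer(
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 1e-3, "weight_decay": 1e-2},
    )
    # stoke_obj = Stoke(model=my_model, optimizer=opt, loss=my_loss, ...)
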
Ancestors (in MRO)

  • builtins.dict

Methods

clear

def clear(
    ...
)

D.clear() -> None. Remove all items from D.

copy

def copy(
    ...
)

D.copy() -> a shallow copy of D

fromkeys

def fromkeys(
    iterable,
    value=None,
    /
)

Create a new dictionary with keys from iterable and values set to value.

get

def get(
    self,
    key,
    default=None,
    /
)

Return the value for key if key is in the dictionary, else default.

items

def items(
    ...
)

D.items() -> a set-like object providing a view on D's items

keys

def keys(
    ...
)

D.keys() -> a set-like object providing a view on D's keys

pop

def pop(
    ...
)

D.pop(k[,d]) -> v, remove specified key and return the corresponding value.

If key is not found, d is returned if given, otherwise KeyError is raised

popitem

def popitem(
    self,
    /
)

Remove and return a (key, value) pair as a 2-tuple.

Pairs are returned in LIFO (last-in, first-out) order. Raises KeyError if the dict is empty.

setdefault

def setdefault(
    self,
    key,
    default=None,
    /
)

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

update

def update(
    ...
)

D.update([E, ]**F) -> None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

values

def values(
    ...
)

D.values() -> an object providing a view on D's values
