
Module stoke.utils

Stoke utility functions


??? example "View Source"

    # -*- coding: utf-8 -*-

    # Copyright FMR LLC <opensource@fidelity.com>

    # SPDX-License-Identifier: Apache-2.0



    """Stoke utility functions"""



    import os

    from enum import Enum

    from typing import Any, Callable, List, Tuple, TypeVar, Union



    import torch

    from fairscale.optim.oss import OSS



    # Taken from torch/utils/data/dataloader

    T_co = TypeVar("T_co", covariant=True)

    T = TypeVar("T")



    # Taken from torch/utils/data/dataloader

    _worker_init_fn_t = Callable[[int], None]

    # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`,

    # but there is currently no way to have that

    # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.

    # See https://github.com/python/mypy/issues/3737.

    _collate_fn_t = Callable[[List[T]], Any]





    class ParamNormalize(Enum):

        """Normalization enum for total number of model parameters used to help with a pretty print"""



        THOUSAND = 1e3

        MILLION = 1e6

        BILLION = 1e9

        TRILLION = 1e12





    def zero_optimizer_grads(

        optimizer: Union[torch.optim.Optimizer, OSS],

        apex: bool = False,

        horovod: bool = False,

    ):

        """Zeros grads depending on if it is a base Torch optimizer or a Fused version from APEX



        Parameters

        ----------

        optimizer: torch.optim.Optimizer

            current optimizer object

        apex: bool, default: False

            if apex is active

        horovod: bool, default: False

            if horovod is active

        Returns

        -------

        None



        """

        if (optimizer.__class__.__name__.find("Fused") == -1) and not apex and not horovod:

            optimizer.zero_grad(set_to_none=True)

        else:

            optimizer.zero_grad()





    def unrolled_print(msg: Union[str, List[str], Tuple[str]], single_line: bool = False):

        """Prints the msg if it's a string or iterable of strings



        Parameters

        ----------

        msg: Union[str, List[str], Tuple[str]]

            string(s) to print

        single_line: bool, default: False

            if iterable print all on one line space and comma separated



        Returns

        -------

        None



        """

        if isinstance(msg, (list, tuple)):

            if single_line:

                msg = type(msg)(

                    f"Stoke -- {val}" if idx == 0 else f"{val}"

                    for idx, val in enumerate(msg)

                )

            else:

                msg = type(msg)(f"Stoke -- {val}" for idx, val in enumerate(msg))

            print(*msg, sep=", " if single_line else "\n")

        else:

            print(f"Stoke -- {msg}")





    def make_folder(path: str):

        """



        Parameters

        ----------

        path: str

            path to write



        Returns

        -------



        """

        # Make the folder if it doesn't exist

        if not os.path.isdir(path):

            os.makedirs(path, exist_ok=True)

Variables

T
T_co

Functions

make_folder

def make_folder(
    path: str
)

Makes the folder at the given path if it doesn't already exist.

Parameters:

| Name | Type | Description   | Default  |
|------|------|---------------|----------|
| path | str  | path to write | required |

??? example "View Source"

    def make_folder(path: str):

        """



        Parameters

        ----------

        path: str

            path to write



        Returns

        -------



        """

        # Make the folder if it doesn't exist

        if not os.path.isdir(path):

            os.makedirs(path, exist_ok=True)
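
For reference, a minimal usage sketch (assuming `stoke.utils` is importable in your environment):

    from stoke.utils import make_folder

    # Creates the nested directories if they do not already exist; a no-op otherwise
    make_folder("./outputs/checkpoints")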

unrolled_print

def unrolled_print(
    msg: Union[str, List[str], Tuple[str]],
    single_line: bool = False
)

Prints the msg if it's a string or iterable of strings

Parameters:

| Name        | Type                              | Description                                                   | Default  |
|-------------|-----------------------------------|---------------------------------------------------------------|----------|
| msg         | Union[str, List[str], Tuple[str]] | string(s) to print                                            | required |
| single_line | bool                              | if iterable, print all on one line, comma and space separated | False    |

Returns:

| Type | Description |
|------|-------------|
| None | None        |

??? example "View Source"

    def unrolled_print(msg: Union[str, List[str], Tuple[str]], single_line: bool = False):

        """Prints the msg if it's a string or iterable of strings



        Parameters

        ----------

        msg: Union[str, List[str], Tuple[str]]

            string(s) to print

        single_line: bool, default: False

            if iterable print all on one line space and comma separated



        Returns

        -------

        None



        """

        if isinstance(msg, (list, tuple)):

            if single_line:

                msg = type(msg)(

                    f"Stoke -- {val}" if idx == 0 else f"{val}"

                    for idx, val in enumerate(msg)

                )

            else:

                msg = type(msg)(f"Stoke -- {val}" for idx, val in enumerate(msg))

            print(*msg, sep=", " if single_line else "\n")

        else:

            print(f"Stoke -- {msg}")

zero_optimizer_grads

def zero_optimizer_grads(
    optimizer: Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS],
    apex: bool = False,
    horovod: bool = False
)

Zeros grads depending on if it is a base Torch optimizer or a Fused version from APEX

Parameters:

| Name      | Type                              | Description              | Default  |
|-----------|-----------------------------------|--------------------------|----------|
| optimizer | Union[torch.optim.Optimizer, OSS] | current optimizer object | required |
| apex      | bool                              | if apex is active        | False    |
| horovod   | bool                              | if horovod is active     | False    |

Returns:

| Type | Description |
|------|-------------|
| None | None        |

??? example "View Source"

    def zero_optimizer_grads(

        optimizer: Union[torch.optim.Optimizer, OSS],

        apex: bool = False,

        horovod: bool = False,

    ):

        """Zeros grads depending on if it is a base Torch optimizer or a Fused version from APEX



        Parameters

        ----------

        optimizer: torch.optim.Optimizer

            current optimizer object

        apex: bool, default: False

            if apex is active

        horovod: bool, default: False

            if horovod is active

        Returns

        -------

        None



        """

        if (optimizer.__class__.__name__.find("Fused") == -1) and not apex and not horovod:

            optimizer.zero_grad(set_to_none=True)

        else:

            optimizer.zero_grad()
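
A minimal sketch with a plain PyTorch optimizer (no APEX fused optimizer, no Horovod), in which case the gradients are cleared via `zero_grad(set_to_none=True)`:

    import torch

    from stoke.utils import zero_optimizer_grads

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Accumulate some gradients and take a step
    model(torch.randn(8, 4)).sum().backward()
    opt.step()

    # Plain torch optimizer and no APEX/Horovod -> zero_grad(set_to_none=True) is used
    zero_optimizer_grads(opt, apex=False, horovod=False)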

Classes

ParamNormalize

class ParamNormalize(
    /,
    *args,
    **kwargs
)

Normalization enum for total number of model parameters used to help with a pretty print

??? example "View Source"

    class ParamNormalize(Enum):

        """Normalization enum for total number of model parameters used to help with a pretty print"""



        THOUSAND = 1e3

        MILLION = 1e6

        BILLION = 1e9

        TRILLION = 1e12

Ancestors (in MRO)

  • enum.Enum

Class variables

BILLION
MILLION
THOUSAND
TRILLION
name
value
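
The enum values are plain divisors; a sketch of how a parameter count could be pretty printed with them (the `torch.nn.Linear` model here is only an illustration):

    import torch

    from stoke.utils import ParamNormalize

    model = torch.nn.Linear(512, 512)
    num_params = sum(p.numel() for p in model.parameters())

    # Divide by the enum's value for a human-readable count, e.g. "262.7K"
    print(f"Total parameters: {num_params / ParamNormalize.THOUSAND.value:.1f}K")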