
Module stoke.data

Handles any data (e.g. loader, sampler, etc.) related classes
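
For example, both public classes documented on this page can be imported directly from this module:

    # Minimal import sketch; both names are defined in the module source below.
    from stoke.data import BucketedDistributedSampler, StokeDataLoader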


??? example "View Source" # -*- coding: utf-8 -*-

    # Copyright FMR LLC <opensource@fidelity.com>

    # SPDX-License-Identifier: Apache-2.0



    """Handles any data (e.g. loader, sampler, etc.) related classes"""



    import itertools

    from math import ceil

    from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union



    import horovod.torch as hvd

    import numpy as np

    import torch

    import torch.distributed as dist

    from torch.utils.data import DataLoader as DL

    from torch.utils.data import Dataset

    from torch.utils.data.distributed import Sampler



    from stoke.status import DistributedOptions, FP16Options

    from stoke.utils import T_co, _collate_fn_t, _worker_init_fn_t





    class StokeDataLoader(DL):

        """Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs



        Attributes

        ----------

        _gpu: bool

        _fp16: Optional[FP16Options]



        See Also

        --------

        torch.utils.data.DataLoader: base DataLoader class that this inherits from (check for all attributes)



        """



        def __init__(

            self,

            gpu: bool,

            fp16: Optional[FP16Options],

            dataset: Dataset[T_co],

            batch_size: Optional[int] = 1,

            shuffle: bool = False,

            sampler: Optional[Sampler[int]] = None,

            batch_sampler: Optional[Sampler[Sequence[int]]] = None,

            num_workers: int = 0,

            collate_fn: _collate_fn_t = None,

            pin_memory: bool = False,

            drop_last: bool = False,

            timeout: float = 0,

            worker_init_fn: _worker_init_fn_t = None,

            multiprocessing_context=None,

            generator=None,

            *,

            prefetch_factor: int = 2,

            persistent_workers: bool = False,

        ):

            """Maps to torch.utils.data.DataLoader __init__



            Shim is necessary to automatically handle device placement since the gpu/fp16 flags can't be

            determined until the StokeStatus object is available which is post init. This could be disconnected from

            this class but it would require the user to forward on device or fp16 configs which breaks the

            paradigm that the flags only need to be set and never handled



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            batch_size: int, default: 1

                how many samples per batch to load.

            shuffle: bool, default: False

                set to ``True`` to have the data reshuffled at every epoch.

            sampler: Sampler or Iterable, default: None

                defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``

                implemented. If specified, :attr:`shuffle` must not be specified.

            batch_sampler: Sampler or Iterable, default: None:

                like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with

                :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.

            num_workers: int, default: 0

                how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.

            collate_fn: callable, optional:

                merges a list of samples to form a mini-batch of Tensor(s).  Used when using batched loading from a

                map-style dataset.

            pin_memory: bool, default: False:

                If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your

                data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,

                see the example below.

            drop_last: bool, default: False

                set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.

                If ``False`` and the size of dataset is not divisible by the batch size, then the last batch

                will be smaller.

            timeout: numeric, default: 0

                if positive, the timeout value for collecting a batch from workers. Should always be non-negative.

            worker_init_fn: callable, default: None

                If not ``None``, this will be called on each worker subprocess with the worker id

                (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.

            prefetch_factor: int, default: 2

                Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers

                samples prefetched across all workers.

            persistent_workers: bool, default: False

                If ``True``, the data loader will not shut down the worker processes after a dataset has been

                consumed once. This keeps the workers' `Dataset` instances alive.



            Returns

            -------

            StokeDataLoader

                wrapped torch.utils.data.DataLoader object



            """

            # Call super init for the actual torch DataLoader

            super(StokeDataLoader, self).__init__(

                dataset=dataset,

                batch_size=batch_size,

                shuffle=shuffle,

                sampler=sampler,

                batch_sampler=batch_sampler,

                num_workers=num_workers,

                collate_fn=collate_fn,

                pin_memory=pin_memory,

                drop_last=drop_last,

                timeout=timeout,

                worker_init_fn=worker_init_fn,

                multiprocessing_context=multiprocessing_context,

                generator=generator,

                prefetch_factor=prefetch_factor,

                persistent_workers=persistent_workers,

            )

            self._gpu = gpu

            self._fp16 = fp16



        def __iter__(self):

            """Underlying iter of the DataLoader that yields samples



            Wrap the base __iter__ with a call to place on the device if flagged



            Yields

            ------

            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                data placed on the correct device



            """

            # Iterate using the base class iter but override the yield by pushing to device prior if gpu flag is true

            for val in super().__iter__():

                yield val if not self._gpu else self._place_data_on_gpu(val)



        def _place_data_on_gpu(

            self,

            data: Union[

                torch.Tensor,

                List[torch.Tensor],

                Tuple[torch.Tensor],

                Dict[str, torch.Tensor],

            ],

        ):

            """Determine data structure and then place on the correct device (cast in the context of deepspeed FP16 as it

            wants half dtype as input)



            Parameters

            ----------

            data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                current data coming from the underlying __iter__



            Returns

            -------

            data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                data moved to the correct device



            """

            if isinstance(data, torch.Tensor):

                # TODO: Check if one of the APEX version needs a cast too?

                # Move to the correct cuda device w/ the correct type -- deepspeed FP16 requires a cast to half if fp16

                if self._fp16 == "deepspeed":

                    return data.to(device="cuda", dtype=torch.half)

                else:

                    return data.to(device="cuda", dtype=data.dtype)

            elif isinstance(data, (list, tuple)):

                return type(data)(self._place_data_on_gpu(data=val) for val in data)

            elif isinstance(data, dict):

                return {k: self._place_data_on_gpu(v) for k, v in data.items()}

            elif not hasattr(data, "to"):

                return data

            else:

                raise TypeError(

                    f"Stoke -- Unsupported data type passed to _place_data_on_gpu "

                    f"(torch.Tensor, tuple, list, dict), currently {type(data)}"

                )





    class BucketedDistributedSampler(Sampler[T_co]):

        """Sampler that buckets samples by sorted_idx and then randomly samples from a specific bucket to prevent excess

        padding leading to wasted computation



        Borrowing heavily from the base DistributedSampler

        https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler



        Attributes

        ----------

        num_replicas: int, default: None

            number of replicas

        rank: int, default: None

            current device rank

        epoch: int

            current training epoch

        drop_last: bool, default: False

            whether to drop last set of samples that don't fit into a batch

        shuffle: bool, default: True

            flag to shuffle dataset

        seed: int, default: 0

            seed to use for generators

        buckets: int

            number of buckets to break the dataset into

        sorted_n_samples: list

            sorted list of samples by the characteristic to bucket by (e.g. seq len)

        batch_size: int

            batch size that will be used (needed to make sure slices are correct)

        allow_bucket_overlap: bool, default: False

            allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

            an un-bucketed batch

        slice_size: int

            computed from batch size and number of replicas

        num_samples_per_bucket: int

            computed value that represents the number of samples in a single bucket

        num_slices_per_bucket: int

            computed value that represents the number of slices available in a bucket

        bucket_idx: list

            computed value that makes a contiguous list of indices in each bucket

        rounded_num_samples_per_bucket: int

            computed value post round for number of samples in a single bucket

        rounded_num_samples_per_replica: int

            computed value post round for number of samples per replica



        """



        def __init__(

            self,

            dataset: Dataset,

            buckets: int,

            batch_size: int,

            sorted_idx: List,

            backend: DistributedOptions,

            allow_bucket_overlap: bool = False,

            num_replicas: Optional[int] = None,

            rank: Optional[int] = None,

            shuffle: bool = True,

            seed: int = 0,

            drop_last: bool = False,

            info_rank: int = 0,

        ) -> None:

            """Init for BucketedDistributedSampler



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            buckets: int

                number of buckets to break the dataset into

            batch_size: int

                batch size that will be used (needed to make sure slices are correct)

            sorted_idx: list

                sorted list of samples by the characteristic to bucket by (e.g. seq len)

            backend: DistributedOptions

                which backend is being used (as rank, world size, etc. need to be used)

            allow_bucket_overlap: bool, default: False

                allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

                an un-bucketed batch

            num_replicas: int, default: None

                number of replicas

            rank: int, default: None

                current device rank

            shuffle: bool, default: True

                flag to shuffle dataset

            seed: int, default: 0

                seed to use for generators

            drop_last: bool, default: False

                whether to drop the last set of samples that don't fit into a batch

            info_rank: int, default: 0

                which device to print information on



            """

            # If the backend isn't DDP there needs to be an additional import

            num_replicas, rank = self._conditional_distributed(

                backend=backend, num_replicas=num_replicas, rank=rank

            )

            self.num_replicas = num_replicas

            self.rank = rank

            self.epoch = 0

            self.drop_last = drop_last

            self.shuffle = shuffle

            self.seed = seed

            self.buckets = buckets

            self.sorted_n_samples = sorted_idx

            # Batch size is needed here so a contiguous iter of buckets can be formed

            self.batch_size = batch_size

            # This is a flag to batch up the dropped samples (that would be 'wasted') if drop_last is flagged

            self.allow_bucket_overlap = allow_bucket_overlap

            # Calculate the size of each slice that will be indexed across the replicas

            self.slice_size = self.batch_size * self.num_replicas

            # Calculate the size of the buckets (rounded or not based on drop last)

            self.num_samples_per_bucket = self._get_size(

                len(dataset), self.buckets, self.drop_last

            )

            # Calculate the number of slices per bucket

            self.num_slices_per_bucket = self._get_size(

                self.num_samples_per_bucket, self.slice_size, self.drop_last

            )

            if self.num_samples_per_bucket < self.slice_size:

                raise ValueError(

                    f"Stoke -- Resulting number of slices (batch * replicas) per bucket "

                    f"({self.num_samples_per_bucket}) is less than the batch size "

                    f"({self.batch_size})"

                )

            if self.num_slices_per_bucket < 2:

                raise ValueError(

                    f"Stoke -- Number of slices per bucket {self.num_slices_per_bucket} is less than 2 "

                    f"which is not recommended"

                )

            if self.num_samples_per_bucket < 100:

                raise ValueError(

                    f"Stoke -- Number of samples per bucket {self.num_samples_per_bucket} is less than 100 "

                    f"which is not recommended as this might lead to dropping of excessive data"

                )

            # Split into buckets and turn into lists

            self.bucket_idx = [

                list(val) for val in np.array_split(self.sorted_n_samples, self.buckets)

            ]

            # Calculate the post rounded numbers

            self.rounded_num_samples_per_bucket = (

                self.slice_size * self.num_slices_per_bucket

            )

            self.rounded_num_samples_per_replica = (

                self.num_slices_per_bucket * self.batch_size * self.buckets

            )

            # Add the bucket overlap samples

            if self.allow_bucket_overlap:

                self.rounded_num_samples_per_replica += (

                    (len(dataset) - (self.rounded_num_samples_per_bucket * self.buckets))

                    // self.slice_size

                ) * self.batch_size

            if self.rank == info_rank:

                print(

                    f"Stoke -- BucketedDistributedSampler -- # Samples Per Bucket: "

                    f"{self.rounded_num_samples_per_bucket}, # of Samples Per Replica: "

                    f"{self.rounded_num_samples_per_replica}"

                )



        def _conditional_distributed(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank

            """

            return self._check_backend(backend, num_replicas, rank)



        def _get_backend_functions(self, backend: DistributedOptions):

            """Gets backend functions if needed



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used



            Returns

            -------

            Tuple[bool, int, int]

                is_init, num_replicas, rank



            """

            if backend.value == "ddp" or backend.value == "deepspeed":

                return (

                    torch.distributed.is_initialized,

                    torch.distributed.get_world_size,

                    torch.distributed.get_rank,

                )

            else:

                return hvd.is_initialized, hvd.size, hvd.rank



        def _check_backend(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """Checks the backend for correct device info



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank



            """

            if num_replicas is None or rank is None:

                is_avail, get_world_size, get_rank = self._get_backend_functions(

                    backend=backend

                )

            if num_replicas is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                num_replicas = get_world_size()

            if rank is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                rank = get_rank()

            return num_replicas, rank



        @staticmethod

        def _get_size(data_len: int, split_var: int, drop_last: bool = False):

            """Gets the size of a split



            Parameters

            ----------

            data_len: int

                current dataset length

            split_var: int

                how many to split into

            drop_last: bool, default: False

                drop last hanging samples if not batch_size



            Returns

            -------

            num_samples: int



            """

            if drop_last:

                num_samples = data_len // split_var

            else:

                num_samples = ceil(data_len / split_var)

            return num_samples



        def __iter__(self) -> Iterator[T_co]:

            """Handles assembling the batches from a bucketed perspective



            Shuffle bucket order->Pad if necessary->Slice across replicas->Possibly batch up residuals->shuffle bucketed

            batches->Unroll into list->Make iter



            Returns

            -------

            Iterator[T_co]



            """

            # Shuffle the bucketed idx

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute each bucket

                indices = [

                    [val[idx] for idx in torch.randperm(len(val), generator=g).tolist()]

                    for val in self.bucket_idx

                ]

            else:

                indices = self.bucket_idx

            # Iterate over the buckets

            for idx, val in enumerate(indices):

                # If this is true we need to handle padding

                if (self.num_slices_per_bucket * self.slice_size) > len(val):

                    split_val = self._handle_padding(val)

                    indices[idx] = list(itertools.chain(*split_val))

                    assert len(indices[idx]) == self.rounded_num_samples_per_bucket

            # Now slice across replicas

            final_indices = []

            for val in indices:

                for idx in range(self.num_slices_per_bucket):

                    replica_slice = val[

                        (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                    ][self.rank : self.slice_size : self.num_replicas]

                    final_indices.append(replica_slice)

            # If bucket overlap is allowed then we just batch up the residual indices

            if self.drop_last and self.allow_bucket_overlap:

                residual_idx = list(

                    itertools.chain(

                        *[val[self.rounded_num_samples_per_bucket :] for val in indices]

                    )

                )

                if len(residual_idx) > self.slice_size:

                    # Cut by slices then by replicas

                    residual_idx = [

                        residual_idx[

                            (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                        ][self.rank : self.slice_size : self.num_replicas]

                        for idx in range(len(residual_idx) // self.slice_size)

                    ]

                    # Append to the final indices

                    final_indices.extend(residual_idx)

            # Shuffle the bucketed batches

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute the bucket order

                final_indices = [

                    final_indices[val]

                    for val in torch.randperm(len(final_indices), generator=g)

                ]

            # Unroll into a single list

            final_indices = list(itertools.chain(*final_indices))

            assert len(final_indices) == self.rounded_num_samples_per_replica

            return iter(final_indices)



        def _handle_padding(self, idx_list: List):

            """Handles padding out if a batch is short



            Parameters

            ----------

            idx_list: List

                list of indices



            Returns

            -------

            split_val: List

                list with correctly padded sizes



            """

            split_val = []

            for idx in range(self.num_slices_per_bucket):

                if idx == (self.num_slices_per_bucket - 1):

                    # Get the short batch

                    short_batch = idx_list[(idx * self.slice_size) :]

                    # Short batch replica slice sizes

                    short_len = [

                        self.batch_size - len(list(val))

                        for val in np.array_split(short_batch, self.num_replicas)

                    ]

                    # Pop the necessary values from the entire bucket

                    pad_values = [

                        idx_list[s_idx : (self.num_replicas * s_len) : self.num_replicas]

                        for s_idx, s_len in enumerate(short_len)

                    ]

                    # If not a consistent list then we need to reorder so that the step size alignment slicing

                    # of the replicas works

                    if len(set(short_len)) != 1:

                        # here we need to find the first larger idx and reorder

                        first_idx = short_len.index(max(set(short_len)))

                        # Reorder

                        pad_values = pad_values[first_idx:] + pad_values[0:first_idx]

                    extended_batch = short_batch + [

                        pad

                        for pad in list(

                            itertools.chain(*itertools.zip_longest(*pad_values))

                        )

                        if pad is not None

                    ]

                    split_val.append(extended_batch)

                else:

                    split_val.append(

                        idx_list[(idx * self.slice_size) : ((idx + 1) * self.slice_size)]

                    )

            return split_val



        def __len__(self) -> int:

            return self.rounded_num_samples_per_replica



        def set_epoch(self, epoch: int) -> None:

            """Sets the epoch for this sampler.



            When :attr:`shuffle=True`, this ensures all replicas

            use a different random ordering for each epoch. Otherwise, the next iteration of this

            sampler will yield the same ordering.



            Parameters

            ----------

            epoch: int

                Epoch number



            """

            self.epoch = epoch

Classes

BucketedDistributedSampler

class BucketedDistributedSampler(
    dataset: torch.utils.data.dataset.Dataset,
    buckets: int,
    batch_size: int,
    sorted_idx: List,
    backend: stoke.status.DistributedOptions,
    allow_bucket_overlap: bool = False,
    num_replicas: Union[int, NoneType] = None,
    rank: Union[int, NoneType] = None,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
    info_rank: int = 0
)

Attributes

| Name | Type | Description | Default |
|------|------|-------------|---------|
| num_replicas | int, default: None | number of replicas | None |
| rank | int, default: None | current device rank | None |
| epoch | int | current training epoch | None |
| drop_last | bool, default: False | whether to drop the last set of samples that don't fit into a batch | None |
| shuffle | bool, default: True | flag to shuffle dataset | None |
| seed | int, default: 0 | seed to use for generators | None |
| buckets | int | number of buckets to break the dataset into | None |
| sorted_n_samples | list | sorted list of samples by the characteristic to bucket by (e.g. seq len) | None |
| batch_size | int | batch size that will be used (needed to make sure slices are correct) | None |
| allow_bucket_overlap | bool, default: False | allow the residual samples (those that are not divisible by batch and num_replicas) to be assembled into an un-bucketed batch | None |
| slice_size | int | computed from batch size and number of replicas | None |
| num_samples_per_bucket | int | computed value that represents the number of samples in a single bucket | None |
| num_slices_per_bucket | int | computed value that represents the number of slices available in a bucket | None |
| bucket_idx | list | computed value that makes a contiguous list of indices in each bucket | None |
| rounded_num_samples_per_bucket | int | computed value post round for number of samples in a single bucket | None |
| rounded_num_samples_per_replica | int | computed value post round for number of samples per replica | None |
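
The sampler only needs a list of dataset indices pre-sorted by the bucketing characteristic (e.g. sequence length). The sketch below is illustrative rather than canonical: the dataset and lengths are hypothetical stand-ins, `num_replicas`/`rank` are passed explicitly so no process group has to be initialized, and `DistributedOptions("ddp")` relies on `"ddp"` being a valid enum value as shown in the backend check in the source.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from stoke.data import BucketedDistributedSampler
    from stoke.status import DistributedOptions

    # Hypothetical dataset of 10,000 variable-length samples
    n = 10_000
    lengths = np.random.randint(5, 512, size=n)   # per-sample sequence lengths (stand-in)
    dataset = TensorDataset(torch.arange(n))      # placeholder dataset of size n
    sorted_idx = np.argsort(lengths).tolist()     # indices sorted by length

    sampler = BucketedDistributedSampler(
        dataset=dataset,
        buckets=10,
        batch_size=16,
        sorted_idx=sorted_idx,
        backend=DistributedOptions("ddp"),  # "ddp" is a valid value per the backend check in the source
        num_replicas=2,                     # explicit, so no initialized process group is required here
        rank=0,
        drop_last=True,
    )

    # The sampler yields flat per-replica indices, so it is passed as `sampler`
    # (not `batch_sampler`) together with the same batch_size.
    loader = DataLoader(dataset, batch_size=16, sampler=sampler, shuffle=False)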

??? example "View Source" class BucketedDistributedSampler(Sampler[T_co]):

        """Sampler that buckets samples by sorted_idx and then randomly samples from a specific bucket to prevent excess

        padding leading to wasted computation



        Borrowing heavily from the base DistributedSampler

        https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler



        Attributes

        ----------

        num_replicas: int, default: None

            number of replicas

        rank: int, default: None

            current device rank

        epoch: int

            current training epoch

        drop_last: bool, default: False

            whether to drop last set of samples that don't fit into a batch

        shuffle: bool, default: True

            flag to shuffle dataset

        seed: int, default: 0

            seed to use for generators

        buckets: int

            number of buckets to break the dataset into

        sorted_n_samples: list

            sorted list of samples by the characteristic to bucket by (e.g. seq len)

        batch_size: int

            batch size that will be used (needed to make sure slices are correct)

        allow_bucket_overlap: bool, default: False

            allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

            an un-bucketed batch

        slice_size: int

            computed from batch size and number of replicas

        num_samples_per_bucket: int

            computed value that represents the number of samples in a single bucket

        num_slices_per_bucket: int

            computed value that represents the number of slices available in a bucket

        bucket_idx: list

            computed value that makes a contiguous list of indices in each bucket

        rounded_num_samples_per_bucket: int

            computed value post round for number of samples in a single bucket

        rounded_num_samples_per_replica: int

            computed value post round for number of samples per replica



        """



        def __init__(

            self,

            dataset: Dataset,

            buckets: int,

            batch_size: int,

            sorted_idx: List,

            backend: DistributedOptions,

            allow_bucket_overlap: bool = False,

            num_replicas: Optional[int] = None,

            rank: Optional[int] = None,

            shuffle: bool = True,

            seed: int = 0,

            drop_last: bool = False,

            info_rank: int = 0,

        ) -> None:

            """Init for BucketedDistributedSampler



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            buckets: int

                number of buckets to break the dataset into

            batch_size: int

                batch size that will be used (needed to make sure slices are correct)

            sorted_idx: list

                sorted list of samples by the characteristic to bucket by (e.g. seq len)

            backend: DistributedOptions

                which backend is being used (as rank, world size, etc. need to be used)

            allow_bucket_overlap: bool, default: False

                allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into

                an un-bucketed batch

            num_replicas: int, default: None

                number of replicas

            rank: int, default: None

                current device rank

            shuffle: bool, default: True

                flag to shuffle dataset

            seed: int, default: 0

                seed to use for generators

            drop_last: bool, default: False

                whether to drop the last set of samples that don't fit into a batch

            info_rank: int, default: 0

                which device to print information on



            """

            # If the backend isn't DDP there needs to be an additional import

            num_replicas, rank = self._conditional_distributed(

                backend=backend, num_replicas=num_replicas, rank=rank

            )

            self.num_replicas = num_replicas

            self.rank = rank

            self.epoch = 0

            self.drop_last = drop_last

            self.shuffle = shuffle

            self.seed = seed

            self.buckets = buckets

            self.sorted_n_samples = sorted_idx

            # Batch size is needed here so a contiguous iter of buckets can be formed

            self.batch_size = batch_size

            # This is a flag to batch up the dropped samples (that would be 'wasted') if drop_last is flagged

            self.allow_bucket_overlap = allow_bucket_overlap

            # Calculate the size of each slice that will be indexed across the replicas

            self.slice_size = self.batch_size * self.num_replicas

            # Calculate the size of the buckets (rounded or not based on drop last)

            self.num_samples_per_bucket = self._get_size(

                len(dataset), self.buckets, self.drop_last

            )

            # Calculate the number of slices per bucket

            self.num_slices_per_bucket = self._get_size(

                self.num_samples_per_bucket, self.slice_size, self.drop_last

            )

            if self.num_samples_per_bucket < self.slice_size:

                raise ValueError(

                    f"Stoke -- Resulting number of slices (batch * replicas) per bucket "

                    f"({self.num_samples_per_bucket}) is less than the batch size "

                    f"({self.batch_size})"

                )

            if self.num_slices_per_bucket < 2:

                raise ValueError(

                    f"Stoke -- Number of slices per bucket {self.num_slices_per_bucket} is less than 2 "

                    f"which is not recommended"

                )

            if self.num_samples_per_bucket < 100:

                raise ValueError(

                    f"Stoke -- Number of samples per bucket {self.num_samples_per_bucket} is less than 100 "

                    f"which is not recommended as this might lead to dropping of excessive data"

                )

            # Split into buckets and turn into lists

            self.bucket_idx = [

                list(val) for val in np.array_split(self.sorted_n_samples, self.buckets)

            ]

            # Calculate the post rounded numbers

            self.rounded_num_samples_per_bucket = (

                self.slice_size * self.num_slices_per_bucket

            )

            self.rounded_num_samples_per_replica = (

                self.num_slices_per_bucket * self.batch_size * self.buckets

            )

            # Add the bucket overlap samples

            if self.allow_bucket_overlap:

                self.rounded_num_samples_per_replica += (

                    (len(dataset) - (self.rounded_num_samples_per_bucket * self.buckets))

                    // self.slice_size

                ) * self.batch_size

            if self.rank == info_rank:

                print(

                    f"Stoke -- BucketedDistributedSampler -- # Samples Per Bucket: "

                    f"{self.rounded_num_samples_per_bucket}, # of Samples Per Replica: "

                    f"{self.rounded_num_samples_per_replica}"

                )



        def _conditional_distributed(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank

            """

            return self._check_backend(backend, num_replicas, rank)



        def _get_backend_functions(self, backend: DistributedOptions):

            """Gets backend functions if needed



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used



            Returns

            -------

            Tuple[bool, int, int]

                is_init, num_replicas, rank



            """

            if backend.value == "ddp" or backend.value == "deepspeed":

                return (

                    torch.distributed.is_initialized,

                    torch.distributed.get_world_size,

                    torch.distributed.get_rank,

                )

            else:

                return hvd.is_initialized, hvd.size, hvd.rank



        def _check_backend(

            self,

            backend: DistributedOptions,

            num_replicas: Optional[int],

            rank: Optional[int],

        ):

            """Checks the backend for correct device info



            Parameters

            ----------

            backend: DistributedOptions

                which backend is being used

            num_replicas: int, default: None

                total number of replicas

            rank: int, default: None

                current device rank



            Returns

            -------

            Tuple[int, int]

                num_replicas, rank



            """

            if num_replicas is None or rank is None:

                is_avail, get_world_size, get_rank = self._get_backend_functions(

                    backend=backend

                )

            if num_replicas is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                num_replicas = get_world_size()

            if rank is None:

                if not is_avail():

                    raise RuntimeError(

                        "Requires distributed package (torch.dist or hvd) to be available"

                    )

                rank = get_rank()

            return num_replicas, rank



        @staticmethod

        def _get_size(data_len: int, split_var: int, drop_last: bool = False):

            """Gets the size of a split



            Parameters

            ----------

            data_len: int

                current dataset length

            split_var: int

                how many to split into

            drop_last: bool, default: False

                drop last hanging samples if not batch_size



            Returns

            -------

            num_samples: int



            """

            if drop_last:

                num_samples = data_len // split_var

            else:

                num_samples = ceil(data_len / split_var)

            return num_samples



        def __iter__(self) -> Iterator[T_co]:

            """Handles assembling the batches from a bucketed perspective



            Shuffle bucket order->Pad if necessary->Slice across replicas->Possibly batch up residuals->shuffle bucketed

            batches->Unroll into list->Make iter



            Returns

            -------

            Iterator[T_co]



            """

            # Shuffle the bucketed idx

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute each bucket

                indices = [

                    [val[idx] for idx in torch.randperm(len(val), generator=g).tolist()]

                    for val in self.bucket_idx

                ]

            else:

                indices = self.bucket_idx

            # Iterate over the buckets

            for idx, val in enumerate(indices):

                # If this is true we need to handle padding

                if (self.num_slices_per_bucket * self.slice_size) > len(val):

                    split_val = self._handle_padding(val)

                    indices[idx] = list(itertools.chain(*split_val))

                    assert len(indices[idx]) == self.rounded_num_samples_per_bucket

            # Now slice across replicas

            final_indices = []

            for val in indices:

                for idx in range(self.num_slices_per_bucket):

                    replica_slice = val[

                        (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                    ][self.rank : self.slice_size : self.num_replicas]

                    final_indices.append(replica_slice)

            # If bucket overlap is allowed then we just batch up the residual indices

            if self.drop_last and self.allow_bucket_overlap:

                residual_idx = list(

                    itertools.chain(

                        *[val[self.rounded_num_samples_per_bucket :] for val in indices]

                    )

                )

                if len(residual_idx) > self.slice_size:

                    # Cut by slices then by replicas

                    residual_idx = [

                        residual_idx[

                            (idx * self.slice_size) : ((idx + 1) * self.slice_size)

                        ][self.rank : self.slice_size : self.num_replicas]

                        for idx in range(len(residual_idx) // self.slice_size)

                    ]

                    # Append to the final indices

                    final_indices.extend(residual_idx)

            # Shuffle the bucketed batches

            if self.shuffle:

                # deterministically shuffle based on epoch and seed

                g = torch.Generator()

                g.manual_seed(self.seed + self.epoch)

                # Permute the bucket order

                final_indices = [

                    final_indices[val]

                    for val in torch.randperm(len(final_indices), generator=g)

                ]

            # Unroll into a single list

            final_indices = list(itertools.chain(*final_indices))

            assert len(final_indices) == self.rounded_num_samples_per_replica

            return iter(final_indices)



        def _handle_padding(self, idx_list: List):

            """Handles padding out if a batch is short



            Parameters

            ----------

            idx_list: List

                list of indices



            Returns

            -------

            split_val: List

                list with correctly padded sizes



            """

            split_val = []

            for idx in range(self.num_slices_per_bucket):

                if idx == (self.num_slices_per_bucket - 1):

                    # Get the short batch

                    short_batch = idx_list[(idx * self.slice_size) :]

                    # Short batch replica slice sizes

                    short_len = [

                        self.batch_size - len(list(val))

                        for val in np.array_split(short_batch, self.num_replicas)

                    ]

                    # Pop the necessary values from the entire bucket

                    pad_values = [

                        idx_list[s_idx : (self.num_replicas * s_len) : self.num_replicas]

                        for s_idx, s_len in enumerate(short_len)

                    ]

                    # If not a consistent list then we need to reorder so that the step size alignment slicing

                    # of the replicas works

                    if len(set(short_len)) != 1:

                        # here we need to find the first larger idx and reorder

                        first_idx = short_len.index(max(set(short_len)))

                        # Reorder

                        pad_values = pad_values[first_idx:] + pad_values[0:first_idx]

                    extended_batch = short_batch + [

                        pad

                        for pad in list(

                            itertools.chain(*itertools.zip_longest(*pad_values))

                        )

                        if pad is not None

                    ]

                    split_val.append(extended_batch)

                else:

                    split_val.append(

                        idx_list[(idx * self.slice_size) : ((idx + 1) * self.slice_size)]

                    )

            return split_val



        def __len__(self) -> int:

            return self.rounded_num_samples_per_replica



        def set_epoch(self, epoch: int) -> None:

            """Sets the epoch for this sampler.



            When :attr:`shuffle=True`, this ensures all replicas

            use a different random ordering for each epoch. Otherwise, the next iteration of this

            sampler will yield the same ordering.



            Parameters

            ----------

            epoch: int

                Epoch number



            """

            self.epoch = epoch

Ancestors (in MRO)

  • torch.utils.data.sampler.Sampler
  • typing.Generic

Methods

set_epoch

def set_epoch(
    self,
    epoch: int
) -> None

Sets the epoch for this sampler.

When `shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| epoch | int | Epoch number | None |
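
In a distributed training loop, `set_epoch` should be called once per epoch before iterating, since the shuffle generator is seeded with `seed + epoch`; otherwise every epoch repeats the same ordering. A short sketch, reusing the hypothetical `sampler`/`loader` objects from the earlier example:

    num_epochs = 3  # illustrative

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)   # re-seeds the per-epoch shuffle identically on every replica
        for batch in loader:
            ...                    # training step goes here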

??? example "View Source" def set_epoch(self, epoch: int) -> None:

            """Sets the epoch for this sampler.



            When :attr:`shuffle=True`, this ensures all replicas

            use a different random ordering for each epoch. Otherwise, the next iteration of this

            sampler will yield the same ordering.



            Parameters

            ----------

            epoch: int

                Epoch number



            """

            self.epoch = epoch

StokeDataLoader

class StokeDataLoader(
    gpu: bool,
    fp16: Union[stoke.status.FP16Options, NoneType],
    dataset: torch.utils.data.dataset.Dataset[+T_co],
    batch_size: Union[int, NoneType] = 1,
    shuffle: bool = False,
    sampler: Union[torch.utils.data.sampler.Sampler[int], NoneType] = None,
    batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence[int]], NoneType] = None,
    num_workers: int = 0,
    collate_fn: Callable[[List[~T]], Any] = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float = 0,
    worker_init_fn: Callable[[int], NoneType] = None,
    multiprocessing_context=None,
    generator=None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = False
)

Attributes

| Name | Type | Description | Default |
|------|------|-------------|---------|
| _gpu | bool | flag that triggers moving each yielded batch to the GPU | None |
| _fp16 | Optional[FP16Options] | active FP16 backend, used to decide if a half-precision cast is needed (deepspeed) | None |
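
Everything after the two leading flags is forwarded unchanged to `torch.utils.data.DataLoader`; the flags only affect what `__iter__` does with each yielded batch. The loader is normally constructed for you once the gpu/fp16 state is known (see the docstring below), but a hand-rolled, CPU-only sketch with illustrative values shows the mapping:

    import torch
    from torch.utils.data import TensorDataset

    from stoke.data import StokeDataLoader

    # Hypothetical in-memory dataset
    dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))

    loader = StokeDataLoader(
        gpu=False,   # batches are yielded as-is, no .to("cuda") in __iter__
        fp16=None,   # no deepspeed half-precision cast
        dataset=dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,
    )

    for features, labels in loader:
        pass  # plain CPU tensors, exactly as torch's DataLoader would yield them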

??? example "View Source" class StokeDataLoader(DL):

        """Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs



        Attributes

        ----------

        _gpu: bool

        _fp16: Optional[FP16Options]



        See Also

        --------

        torch.utils.data.DataLoader: base DataLoader class that this inherits from (check for all attributes)



        """



        def __init__(

            self,

            gpu: bool,

            fp16: Optional[FP16Options],

            dataset: Dataset[T_co],

            batch_size: Optional[int] = 1,

            shuffle: bool = False,

            sampler: Optional[Sampler[int]] = None,

            batch_sampler: Optional[Sampler[Sequence[int]]] = None,

            num_workers: int = 0,

            collate_fn: _collate_fn_t = None,

            pin_memory: bool = False,

            drop_last: bool = False,

            timeout: float = 0,

            worker_init_fn: _worker_init_fn_t = None,

            multiprocessing_context=None,

            generator=None,

            *,

            prefetch_factor: int = 2,

            persistent_workers: bool = False,

        ):

            """Maps to torch.utils.data.DataLoader __init__



            Shim is necessary to automatically handle device placement since the gpu/fp16 flags can't be

            determined until the StokeStatus object is available which is post init. This could be disconnected from

            this class but it would require the user to forward on device or fp16 configs which breaks the

            paradigm that the flags only need to be set and never handled



            Parameters

            ----------

            dataset: Dataset

                dataset from which to load the data.

            batch_size: int, default: 1

                how many samples per batch to load.

            shuffle: bool, default: False

                set to ``True`` to have the data reshuffled at every epoch.

            sampler: Sampler or Iterable, default: None

                defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``

                implemented. If specified, :attr:`shuffle` must not be specified.

            batch_sampler: Sampler or Iterable, default: None:

                like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with

                :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.

            num_workers: int, default: 0

                how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.

            collate_fn: callable, optional:

                merges a list of samples to form a mini-batch of Tensor(s).  Used when using batched loading from a

                map-style dataset.

            pin_memory: bool, default: False:

                If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your

                data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,

                see the example below.

            drop_last: bool, default: False

                set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.

                If ``False`` and the size of dataset is not divisible by the batch size, then the last batch

                will be smaller.

            timeout: numeric, default: 0

                if positive, the timeout value for collecting a batch from workers. Should always be non-negative.

            worker_init_fn: callable, default: None

                If not ``None``, this will be called on each worker subprocess with the worker id

                (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.

            prefetch_factor: int, default: 2

                Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers

                samples prefetched across all workers.

            persistent_workers: bool, default: False

                If ``True``, the data loader will not shut down the worker processes after a dataset has been

                consumed once. This keeps the workers' `Dataset` instances alive.



            Returns

            -------

            StokeDataLoader

                wrapped torch.utils.data.DataLoader object



            """

            # Call super init for the actual torch DataLoader

            super(StokeDataLoader, self).__init__(

                dataset=dataset,

                batch_size=batch_size,

                shuffle=shuffle,

                sampler=sampler,

                batch_sampler=batch_sampler,

                num_workers=num_workers,

                collate_fn=collate_fn,

                pin_memory=pin_memory,

                drop_last=drop_last,

                timeout=timeout,

                worker_init_fn=worker_init_fn,

                multiprocessing_context=multiprocessing_context,

                generator=generator,

                prefetch_factor=prefetch_factor,

                persistent_workers=persistent_workers,

            )

            self._gpu = gpu

            self._fp16 = fp16



        def __iter__(self):

            """Underlying iter of the DataLoader that yields samples



            Wrap the base __iter__ with a call to place on the device if flagged



            Yields

            ------

            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                data placed on the correct device



            """

            # Iterate using the base class iter but override the yield by pushing to device prior if gpu flag is true

            for val in super().__iter__():

                yield val if not self._gpu else self._place_data_on_gpu(val)



        def _place_data_on_gpu(

            self,

            data: Union[

                torch.Tensor,

                List[torch.Tensor],

                Tuple[torch.Tensor],

                Dict[str, torch.Tensor],

            ],

        ):

            """Determine data structure and then place on the correct device (cast in the context of deepspeed FP16 as it

            wants half dtype as input)



            Parameters

            ----------

            data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                current data coming from the underlying __iter__



            Returns

            -------

            data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]

                data moved to the correct device



            """

            if isinstance(data, torch.Tensor):

                # TODO: Check if one of the APEX version needs a cast too?

                # Move to the correct cuda device w/ the correct type -- deepspeed FP16 requires a cast to half if fp16

                if self._fp16 == "deepspeed":

                    return data.to(device="cuda", dtype=torch.half)

                else:

                    return data.to(device="cuda", dtype=data.dtype)

            elif isinstance(data, (list, tuple)):

                return type(data)(self._place_data_on_gpu(data=val) for val in data)

            elif isinstance(data, dict):

                return {k: self._place_data_on_gpu(v) for k, v in data.items()}

            elif not hasattr(data, "to"):

                return data

            else:

                raise TypeError(

                    f"Stoke -- Unsupported data type passed to _place_data_on_gpu "

                    f"(torch.Tensor, tuple, list, dict), currently {type(data)}"

                )

Ancestors (in MRO)

  • torch.utils.data.dataloader.DataLoader
  • typing.Generic

Instance variables

multiprocessing_context

Methods

check_worker_number_rationality

def check_worker_number_rationality(
    self
)

??? example "View Source" def check_worker_number_rationality(self):

            # This function checks whether the dataloader's worker number is rational based on

            # current system's resource. Current rule is that if the number of workers this

            # Dataloader will create is bigger than the number of logical cpus that is allowed to

            # use, then we will pop up a warning to let user pay attention.

            #

            # eg. If current system has 2 physical CPUs with 16 cores each. And each core supports 2

            #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current

            #     DataLoader process can use half of them which is 32, then the rational max number of

            #     worker that initiated from this process is 32.

            #     Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.

            #     So the warning message is triggered to notify the user to lower the worker number if

            #     necessary.

            #

            #

            # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is

            #        available (available in most of Linux system, but not OSX and Windows).

            #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but

            #        it doesn't respect cpuset.

            #        We don't take threading into account since each worker process is single threaded

            #        at this time.

            #

            #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)

            #        other than `torch.set_num_threads` to 1 in the worker process, if the passing

            #        in functions use 3rd party modules that rely on those threading flags to determine

            #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to

            #        set those flags correctly.

            def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):



                suggested_max_worker_msg = ((

                    "Our suggested max number of worker in current system is {}{}, which is smaller "

                    "than what this DataLoader is going to create.").format(

                        num_worker_suggest,

                        ("" if cpuset_checked else " (`cpuset` is not taken into account)"))

                ) if num_worker_suggest is not None else (

                    "DataLoader is not able to compute a suggested max number of worker in current system.")



                warn_msg = (

                    "This DataLoader will create {} worker processes in total. {} "

                    "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "

                    "lower the worker number to avoid potential slowness/freeze if necessary.").format(

                        num_worker_created,

                        suggested_max_worker_msg)

                return warn_msg



            if not self.num_workers or self.num_workers == 0:

                return



            # try to compute a suggested max number of worker based on system's resource

            max_num_worker_suggest = None

            cpuset_checked = False

            if hasattr(os, 'sched_getaffinity'):

                try:

                    max_num_worker_suggest = len(os.sched_getaffinity(0))

                    cpuset_checked = True

                except Exception:

                    pass

            if max_num_worker_suggest is None:

                # os.cpu_count() could return Optional[int]

                # get cpu count first and check None in order to satisfy mypy check

                cpu_count = os.cpu_count()

                if cpu_count is not None:

                    max_num_worker_suggest = cpu_count



            if max_num_worker_suggest is None:

                warnings.warn(_create_warning_msg(

                    max_num_worker_suggest,

                    self.num_workers,

                    cpuset_checked))

                return



            if self.num_workers > max_num_worker_suggest:

                warnings.warn(_create_warning_msg(

                    max_num_worker_suggest,

                    self.num_workers,

                    cpuset_checked))