Module stoke.data
Handles any data (e.g. loader, sampler, etc.) related classes
??? example "View Source" # -- coding: utf-8 --
# Copyright FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0
"""Handles any data (e.g. loader, sampler, etc.) related classes"""
import itertools
from math import ceil
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union
import horovod.torch as hvd
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader as DL
from torch.utils.data import Dataset
from torch.utils.data.distributed import Sampler
from stoke.status import DistributedOptions, FP16Options
from stoke.utils import T_co, _collate_fn_t, _worker_init_fn_t
class StokeDataLoader(DL):
"""Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs
Attributes
----------
_gpu: bool
_fp16: Optional[FP16Options]
See Also
--------
torch.utils.data.DataLoader: base DataLoader class that this inherits from (check for all attributes)
"""
def __init__(
self,
gpu: bool,
fp16: Optional[FP16Options],
dataset: Dataset[T_co],
batch_size: Optional[int] = 1,
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: _collate_fn_t = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
):
"""Maps to torch.utils.data.DataLoader __init__
The shim is necessary to automatically handle device placement, since the gpu/fp16 flags can't be
determined until the StokeStatus object is available (which is post init). This could be disconnected from
this class, but it would require the user to forward device or fp16 configs, which breaks the
paradigm that the flags only need to be set and never handled.
Parameters
----------
dataset: Dataset
dataset from which to load the data.
batch_size: int, default: 1
how many samples per batch to load.
shuffle: bool, default: False
set to ``True`` to have the data reshuffled at every epoch.
sampler: Sampler or Iterable, default: None
defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler: Sampler or Iterable, default: None
like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers: int, default: 0
how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
collate_fn: callable, optional
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory: bool, default: False
If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the memory pinning example in the PyTorch DataLoader documentation.
drop_last: bool, default: False
set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
will be smaller.
timeout: numeric, default: 0
if positive, the timeout value for collecting a batch from workers. Should always be non-negative.
worker_init_fn: callable, default: None
If not ``None``, this will be called on each worker subprocess with the worker id
(an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.
prefetch_factor: int, default: 2
Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers
samples prefetched across all workers.
persistent_workers: bool, default: False
If ``True``, the data loader will not shut down the worker processes after a dataset has been
consumed once. This allows the workers' `Dataset` instances to remain alive.
Returns
-------
StokeDataLoader
wrapped torch.utils.data.DataLoader object
"""
# Call super init for the actual torch DataLoader
super(StokeDataLoader, self).__init__(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
self._gpu = gpu
self._fp16 = fp16
def __iter__(self):
"""Underlying iter of the DataLoader that yields samples
Wrap the base __iter__ with a call to place on the device if flagged
Yields
------
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
data placed on the correct device
"""
# Iterate using the base class iter but override the yield by pushing to device prior if gpu flag is true
for val in super().__iter__():
yield val if not self._gpu else self._place_data_on_gpu(val)
def _place_data_on_gpu(
self,
data: Union[
torch.Tensor,
List[torch.Tensor],
Tuple[torch.Tensor],
Dict[str, torch.Tensor],
],
):
"""Determine data structure and then place on the correct device (cast in the context of deepspeed FP16 as it
wants half dtype as input)
Parameters
----------
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
current data coming from the underlying __iter__
Returns
-------
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
data moved to the correct device
"""
if isinstance(data, torch.Tensor):
# TODO: Check if one of the APEX version needs a cast too?
# Move to the correct cuda device w/ the correct type -- deepspeed FP16 requires a cast to half if fp16
if self._fp16 == "deepspeed":
return data.to(device="cuda", dtype=torch.half)
else:
return data.to(device="cuda", dtype=data.dtype)
elif isinstance(data, (list, tuple)):
return type(data)(self._place_data_on_gpu(data=val) for val in data)
elif isinstance(data, dict):
return {k: self._place_data_on_gpu(v) for k, v in data.items()}
elif not hasattr(data, "to"):
return data
else:
raise TypeError(
f"Stoke -- Unsupported data type passed to _place_data_on_gpu "
f"(torch.Tensor, tuple, list, dict), currently {type(data)}"
)
class BucketedDistributedSampler(Sampler[T_co]):
"""Sampler that buckets samples by sorted_idx and then randomly samples from a specific bucket to prevent excess
padding leading to wasted computation
Borrowing heavily from the base DistributedSampler
https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler
Attributes
----------
num_replicas: int, default: None
number of replicas
rank: int, default: None
current device rank
epoch: int
current training epoch
drop_last: bool, default: False
whether to drop last set of samples that don't fit into a batch
shuffle: bool, default: True
flag to shuffle dataset
seed: int, default: 0
seed to use for generators
buckets: int
number of buckets to break the dataset into
sorted_n_samples: list
sorted list of samples by the characteristic to bucket by (e.g. seq len)
batch_size: int
batch size that will be used (needed to make sure slices are correct)
allow_bucket_overlap: bool, default: False
allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into
an un-bucketed batch
slice_size: int
computed from batch size and number of replicas
num_samples_per_bucket: int
computed value that represents the number of samples in a single bucket
num_slices_per_bucket: int
computed value that represents the number of slices available in a bucket
bucket_idx: list
computed value that makes a contiguous list of indices in each bucket
rounded_num_samples_per_bucket: int
computed value post round for number of samples in a single bucket
rounded_num_samples_per_replica: int
computed value post round for the number of samples per replica
"""
def __init__(
self,
dataset: Dataset,
buckets: int,
batch_size: int,
sorted_idx: List,
backend: DistributedOptions,
allow_bucket_overlap: bool = False,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
info_rank: int = 0,
) -> None:
"""Init for BucketedDistributedSampler
Parameters
----------
dataset: Dataset
dataset from which to load the data.
buckets: int
number of buckets to break the dataset into
batch_size: int
batch size that will be used (needed to make sure slices are correct)
sorted_idx: list
sorted list of samples by the characteristic to bucket by (e.g. seq len)
backend: DistributedOptions
which backend is being used (as rank, world size, etc. need to be used)
allow_bucket_overlap: bool, default: False
allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into
an un-bucketed batch
num_replicas: int, default: None
number of replicas
rank: int, default: None
current device rank
shuffle: bool, default: True
flag to shuffle dataset
seed: int, default: 0
seed to use for generators
drop_last: bool, default: False
whether to drop last set of samples that don't fit into a batch
info_rank: int, default: 0
which device to print information on
"""
# If the backend isn't DDP there needs to be an additional import
num_replicas, rank = self._conditional_distributed(
backend=backend, num_replicas=num_replicas, rank=rank
)
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
self.shuffle = shuffle
self.seed = seed
self.buckets = buckets
self.sorted_n_samples = sorted_idx
# Batch size is needed here so a contiguous iter of buckets can be formed
self.batch_size = batch_size
# This is a flag to batch up the dropped samples (that would be 'wasted') if drop_last is flagged
self.allow_bucket_overlap = allow_bucket_overlap
# Calculate the size of each slice that will be indexed across the replicas
self.slice_size = self.batch_size * self.num_replicas
# Calculate the size of the buckets (rounded or not based on drop last)
self.num_samples_per_bucket = self._get_size(
len(dataset), self.buckets, self.drop_last
)
# Calculate the number of slices per bucket
self.num_slices_per_bucket = self._get_size(
self.num_samples_per_bucket, self.slice_size, self.drop_last
)
if self.num_samples_per_bucket < self.slice_size:
raise ValueError(
f"Stoke -- Resulting number of slices (batch * replicas) per bucket "
f"({self.num_samples_per_bucket}) is less than the batch size "
f"({self.batch_size})"
)
if self.num_slices_per_bucket < 2:
raise ValueError(
f"Stoke -- Number of slices per bucket {self.num_slices_per_bucket} is less than 2 "
f"which is not recommended"
)
if self.num_samples_per_bucket < 100:
raise ValueError(
f"Stoke -- Number of samples per bucket {self.num_samples_per_bucket} is less than 100 "
f"which is not recommended as this might lead to dropping of excessive data"
)
# Split into buckets and turn into lists
self.bucket_idx = [
list(val) for val in np.array_split(self.sorted_n_samples, self.buckets)
]
# Calculate the post rounded numbers
self.rounded_num_samples_per_bucket = (
self.slice_size * self.num_slices_per_bucket
)
self.rounded_num_samples_per_replica = (
self.num_slices_per_bucket * self.batch_size * self.buckets
)
# Add the bucket overlap samples
if self.allow_bucket_overlap:
self.rounded_num_samples_per_replica += (
(len(dataset) - (self.rounded_num_samples_per_bucket * self.buckets))
// self.slice_size
) * self.batch_size
if self.rank == info_rank:
print(
f"Stoke -- BucketedDistributedSampler -- # Samples Per Bucket: "
f"{self.rounded_num_samples_per_bucket}, # of Samples Per Replica: "
f"{self.rounded_num_samples_per_replica}"
)
def _conditional_distributed(
self,
backend: DistributedOptions,
num_replicas: Optional[int],
rank: Optional[int],
):
"""
Parameters
----------
backend: DistributedOptions
which backend is being used
num_replicas: int, default: None
total number of replicas
rank: int, default: None
current device rank
Returns
-------
Tuple[int, int]
num_replicas, rank
"""
return self._check_backend(backend, num_replicas, rank)
def _get_backend_functions(self, backend: DistributedOptions):
"""Gets backend functions if needed
Parameters
----------
backend: DistributedOptions
which backend is being used
Returns
-------
Tuple[bool, int, int]
is_init, num_replicas, rank
"""
if backend.value == "ddp" or backend.value == "deepspeed":
return (
torch.distributed.is_initialized,
torch.distributed.get_world_size,
torch.distributed.get_rank,
)
else:
return hvd.is_initialized, hvd.size, hvd.rank
def _check_backend(
self,
backend: DistributedOptions,
num_replicas: Optional[int],
rank: Optional[int],
):
"""Checks the backend for correct device info
Parameters
----------
backend: DistributedOptions
which backend is being used
num_replicas: int, default: None
total number of replicas
rank: int, default: None
current device rank
Returns
-------
Tuple[int, int]
num_replicas, rank
"""
if num_replicas is None or rank is None:
is_avail, get_world_size, get_rank = self._get_backend_functions(
backend=backend
)
if num_replicas is None:
if not is_avail():
raise RuntimeError(
"Requires distributed package (torch.dist or hvd) to be available"
)
num_replicas = get_world_size()
if rank is None:
if not is_avail():
raise RuntimeError(
"Requires distributed package (torch.dist or hvd) to be available"
)
rank = get_rank()
return num_replicas, rank
@staticmethod
def _get_size(data_len: int, split_var: int, drop_last: bool = False):
"""Gets the size of a split
Parameters
----------
data_len: int
current dataset length
split_var: int
how many to split into
drop_last: bool, default: False
drop last hanging samples if not batch_size
Returns
-------
num_samples: int
"""
if drop_last:
num_samples = data_len // split_var
else:
num_samples = ceil(data_len / split_var)
return num_samples
def __iter__(self) -> Iterator[T_co]:
"""Handles assembling the batches from a bucketed perspective
Shuffle bucket order->Pad if necessary->Slice across replicas->Possibly batch up residuals->shuffle bucketed
batches->Unroll into list->Make iter
Returns
-------
Iterator[T_co]
"""
# Shuffle the bucketed idx
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Permute each bucket
indices = [
[val[idx] for idx in torch.randperm(len(val), generator=g).tolist()]
for val in self.bucket_idx
]
else:
indices = self.bucket_idx
# Iterate over the buckets
for idx, val in enumerate(indices):
# If this is true we need to handle padding
if (self.num_slices_per_bucket * self.slice_size) > len(val):
split_val = self._handle_padding(val)
indices[idx] = list(itertools.chain(*split_val))
assert len(indices[idx]) == self.rounded_num_samples_per_bucket
# Now slice across replicas
final_indices = []
for val in indices:
for idx in range(self.num_slices_per_bucket):
replica_slice = val[
(idx * self.slice_size) : ((idx + 1) * self.slice_size)
][self.rank : self.slice_size : self.num_replicas]
final_indices.append(replica_slice)
# If bucket overlap is allowed then we just batch up the residual indices
if self.drop_last and self.allow_bucket_overlap:
residual_idx = list(
itertools.chain(
*[val[self.rounded_num_samples_per_bucket :] for val in indices]
)
)
if len(residual_idx) > self.slice_size:
# Cut by slices then by replicas
residual_idx = [
residual_idx[
(idx * self.slice_size) : ((idx + 1) * self.slice_size)
][self.rank : self.slice_size : self.num_replicas]
for idx in range(len(residual_idx) // self.slice_size)
]
# Append to the final indices
final_indices.extend(residual_idx)
# Shuffle the bucketed batches
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Permute the bucket order
final_indices = [
final_indices[val]
for val in torch.randperm(len(final_indices), generator=g)
]
# Unroll into a single list
final_indices = list(itertools.chain(*final_indices))
assert len(final_indices) == self.rounded_num_samples_per_replica
return iter(final_indices)
def _handle_padding(self, idx_list: List):
"""Handles padding out if a batch is short
Parameters
----------
idx_list: List
list of indices
Returns
-------
split_val: List
list with correctly padded sizes
"""
split_val = []
for idx in range(self.num_slices_per_bucket):
if idx == (self.num_slices_per_bucket - 1):
# Get the short batch
short_batch = idx_list[(idx * self.slice_size) :]
# Short batch replica slice sizes
short_len = [
self.batch_size - len(list(val))
for val in np.array_split(short_batch, self.num_replicas)
]
# Pop the necessary values from the entire bucket
pad_values = [
idx_list[s_idx : (self.num_replicas * s_len) : self.num_replicas]
for s_idx, s_len in enumerate(short_len)
]
# If not a consistent list then we need to reorder so that the step size alignment slicing
# of the replicas works
if len(set(short_len)) != 1:
# here we need to find the first larger idx and reorder
first_idx = short_len.index(max(set(short_len)))
# Reorder
pad_values = pad_values[first_idx:] + pad_values[0:first_idx]
extended_batch = short_batch + [
pad
for pad in list(
itertools.chain(*itertools.zip_longest(*pad_values))
)
if pad is not None
]
split_val.append(extended_batch)
else:
split_val.append(
idx_list[(idx * self.slice_size) : ((idx + 1) * self.slice_size)]
)
return split_val
def __len__(self) -> int:
return self.rounded_num_samples_per_replica
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Parameters
----------
epoch: int
Epoch number
"""
self.epoch = epoch
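Before the per-class reference below, a quick orientation: when the `gpu` flag is true, `StokeDataLoader.__iter__` recursively pushes every yielded tensor (including tensors nested in lists, tuples, and dicts) onto the CUDA device, casting to half only for the deepspeed FP16 backend. The following is a minimal sketch of that behaviour; the `TensorDataset` contents are made up for illustration, and in real use the `gpu`/`fp16` flags are forwarded from `StokeStatus` rather than being passed by hand:

```python
import torch
from torch.utils.data import TensorDataset

from stoke.data import StokeDataLoader

# Hypothetical map-style dataset of (feature, label) pairs
my_dataset = TensorDataset(torch.randn(128, 8), torch.randint(0, 2, (128,)))

# gpu/fp16 are normally forwarded from StokeStatus; set explicitly here only for illustration
loader = StokeDataLoader(
    gpu=torch.cuda.is_available(),
    fp16=None,
    dataset=my_dataset,
    batch_size=16,
)

for features, labels in loader:
    # With gpu=True both tensors were already moved to "cuda" by _place_data_on_gpu
    print(features.device, labels.device)
    break
```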
Classes
BucketedDistributedSampler
class BucketedDistributedSampler(
dataset: torch.utils.data.dataset.Dataset,
buckets: int,
batch_size: int,
sorted_idx: List,
backend: stoke.status.DistributedOptions,
allow_bucket_overlap: bool = False,
num_replicas: Union[int, NoneType] = None,
rank: Union[int, NoneType] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
info_rank: int = 0
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
num_replicas | int, default: None | number of replicas | None |
rank | int, default: None | current device rank | None |
epoch | int | current training epoch | None |
drop_last | bool, default: False | whether to drop last set of samples that don't fit into a batch | None |
shuffle | bool, default: True | flag to shuffle dataset | None |
seed | int, default: 0 | seed to use for generators | None |
buckets | int | number of buckets to break the dataset into | None |
sorted_n_samples | list | sorted list of samples by the characteristic to bucket by (e.g. seq len) | None |
batch_size | int | batch size that will be used (needed to make sure slices are correct) | None |
allow_bucket_overlap | bool, default: False | allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into an un-bucketed batch | None |
slice_size | int | computed from batch size and number of replicas | None |
num_samples_per_bucket | int | computed value that represents the number of samples in a single bucket | None |
num_slices_per_bucket | int | computed value that represents the number of slices available in a bucket | None |
bucket_idx | list | computed value that makes a contiguous list of indices in each bucket | None |
rounded_num_samples_per_bucket | int | computed value post round for number of samples in a single bucket | None |
rounded_num_samples_per_replica | int | computed value post round for the number of samples per replica | None |
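The key input is `sorted_idx`: dataset indices ordered by whatever characteristic drives the padding (typically sequence length), so that each bucket holds samples of similar size. A rough construction sketch follows, assuming `torch.distributed` (DDP) is already initialized, `my_dataset` is a hypothetical 10,000-sample map-style dataset, and the `DistributedOptions` enum exposes a `ddp` member matching the `"ddp"` value checked in the source:

```python
import numpy as np

from stoke.status import DistributedOptions
from stoke.data import BucketedDistributedSampler

# Hypothetical per-sample sequence lengths for a 10,000-sample dataset
sequence_lengths = np.random.randint(5, 512, size=10_000)
# Dataset indices sorted by the characteristic we want to bucket on
sorted_idx = list(np.argsort(sequence_lengths))

sampler = BucketedDistributedSampler(
    dataset=my_dataset,              # assumed map-style Dataset with len() == 10_000
    buckets=8,                       # ~1,250 samples per bucket
    batch_size=32,
    sorted_idx=sorted_idx,
    backend=DistributedOptions.ddp,  # assumed member name for the "ddp" backend value
    drop_last=True,
)
```

Passing this sampler to a DataLoader via its `sampler` argument then yields batches whose members come from the same bucket, which is what keeps padding (and wasted computation) down.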
??? example "View Source" class BucketedDistributedSampler(Sampler[T_co]):
"""Sampler that buckets samples by sorted_idx and then randomly samples from a specific bucket to prevent excess
padding leading to wasted computation
Borrowing heavily from the base DistributedSampler
https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler
Attributes
----------
num_replicas: int, default: None
number of replicas
rank: int, default: None
current device rank
epoch: int
current training epoch
drop_last: bool, default: False
whether to drop last set of samples that don't fit into a batch
shuffle: bool, default: True
flag to shuffle dataset
seed: int, default: 0
seed to use for generators
buckets: int
number of buckets to break the dataset into
sorted_n_samples: list
sorted list of samples by the characteristic to bucket by (e.g. seq len)
batch_size: int
batch size that will be used (needed to make sure slices are correct)
allow_bucket_overlap: bool, default: False
allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into
an un-bucketed batch
slice_size: int
computed from batch size and number of replicas
num_samples_per_bucket: int
computed value that represents the number of samples in a single bucket
num_slices_per_bucket: int
computed value that represents the number of slices available in a bucket
bucket_idx: list
computed value that makes a contiguous list of indices in each bucket
rounded_num_samples_per_bucket: int
computed value post round for number of samples in a single bucket
rounded_num_samples_per_replica: int
computed value post round for the number of samples per replica
"""
def __init__(
self,
dataset: Dataset,
buckets: int,
batch_size: int,
sorted_idx: List,
backend: DistributedOptions,
allow_bucket_overlap: bool = False,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
info_rank: int = 0,
) -> None:
"""Init for BucketedDistributedSampler
Parameters
----------
dataset: Dataset
dataset from which to load the data.
buckets: int
number of buckets to break the dataset into
batch_size: int
batch size that will be used (needed to make sure slices are correct)
sorted_idx: list
sorted list of samples by the characteristic to bucket by (e.g. seq len)
backend: DistributedOptions
which backend is being used (as rank, world size, etc. need to be used)
allow_bucket_overlap: bool, default: False
allow for the residual samples (those that are not divisible by batch and num_replicas) to be assembled into
an un-bucketed batch
num_replicas: int, default: None
number of replicas
rank: int, default: None
current device rank
shuffle: bool, default: True
flag to shuffle dataset
seed: int, default: 0
seed to use for generators
drop_last: bool, default: False
whether to drop last set of samples that don't fit into a batch
info_rank: int, default: 0
which device to print information on
"""
# If the backend isn't DDP there needs to be an additional import
num_replicas, rank = self._conditional_distributed(
backend=backend, num_replicas=num_replicas, rank=rank
)
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
self.shuffle = shuffle
self.seed = seed
self.buckets = buckets
self.sorted_n_samples = sorted_idx
# Batch size is needed here so a contiguous iter of buckets can be formed
self.batch_size = batch_size
# This is a flag to batch up the dropped samples (that would be 'wasted') if drop_last is flagged
self.allow_bucket_overlap = allow_bucket_overlap
# Calculate the size of each slice that will be indexed across the replicas
self.slice_size = self.batch_size * self.num_replicas
# Calculate the size of the buckets (rounded or not based on drop last)
self.num_samples_per_bucket = self._get_size(
len(dataset), self.buckets, self.drop_last
)
# Calculate the number of slices per bucket
self.num_slices_per_bucket = self._get_size(
self.num_samples_per_bucket, self.slice_size, self.drop_last
)
if self.num_samples_per_bucket < self.slice_size:
raise ValueError(
f"Stoke -- Resulting number of slices (batch * replicas) per bucket "
f"({self.num_samples_per_bucket}) is less than the batch size "
f"({self.batch_size})"
)
if self.num_slices_per_bucket < 2:
raise ValueError(
f"Stoke -- Number of slices per bucket {self.num_slices_per_bucket} is less than 2 "
f"which is not recommended"
)
if self.num_samples_per_bucket < 100:
raise ValueError(
f"Stoke -- Number of samples per bucket {self.num_samples_per_bucket} is less than 100 "
f"which is not recommended as this might lead to dropping of excessive data"
)
# Split into buckets and turn into lists
self.bucket_idx = [
list(val) for val in np.array_split(self.sorted_n_samples, self.buckets)
]
# Calculate the post rounded numbers
self.rounded_num_samples_per_bucket = (
self.slice_size * self.num_slices_per_bucket
)
self.rounded_num_samples_per_replica = (
self.num_slices_per_bucket * self.batch_size * self.buckets
)
# Add the bucket overlap samples
if self.allow_bucket_overlap:
self.rounded_num_samples_per_replica += (
(len(dataset) - (self.rounded_num_samples_per_bucket * self.buckets))
// self.slice_size
) * self.batch_size
if self.rank == info_rank:
print(
f"Stoke -- BucketedDistributedSampler -- # Samples Per Bucket: "
f"{self.rounded_num_samples_per_bucket}, # of Samples Per Replica: "
f"{self.rounded_num_samples_per_replica}"
)
def _conditional_distributed(
self,
backend: DistributedOptions,
num_replicas: Optional[int],
rank: Optional[int],
):
"""
Parameters
----------
backend: DistributedOptions
which backend is being used
num_replicas: int, default: None
total number of replicas
rank: int, default: None
current device rank
Returns
-------
Tuple[int, int]
num_replicas, rank
"""
return self._check_backend(backend, num_replicas, rank)
def _get_backend_functions(self, backend: DistributedOptions):
"""Gets backend functions if needed
Parameters
----------
backend: DistributedOptions
which backend is being used
Returns
-------
Tuple[bool, int, int]
is_init, num_replicas, rank
"""
if backend.value == "ddp" or backend.value == "deepspeed":
return (
torch.distributed.is_initialized,
torch.distributed.get_world_size,
torch.distributed.get_rank,
)
else:
return hvd.is_initialized, hvd.size, hvd.rank
def _check_backend(
self,
backend: DistributedOptions,
num_replicas: Optional[int],
rank: Optional[int],
):
"""Checks the backend for correct device info
Parameters
----------
backend: DistributedOptions
which backend is being used
num_replicas: int, default: None
total number of replicas
rank: int, default: None
current device rank
Returns
-------
Tuple[int, int]
num_replicas, rank
"""
if num_replicas is None or rank is None:
is_avail, get_world_size, get_rank = self._get_backend_functions(
backend=backend
)
if num_replicas is None:
if not is_avail():
raise RuntimeError(
"Requires distributed package (torch.dist or hvd) to be available"
)
num_replicas = get_world_size()
if rank is None:
if not is_avail():
raise RuntimeError(
"Requires distributed package (torch.dist or hvd) to be available"
)
rank = get_rank()
return num_replicas, rank
@staticmethod
def _get_size(data_len: int, split_var: int, drop_last: bool = False):
"""Gets the size of a split
Parameters
----------
data_len: int
current dataset length
split_var: int
how many to split into
drop_last: bool, default: False
drop last hanging samples if not batch_size
Returns
-------
num_samples: int
"""
if drop_last:
num_samples = data_len // split_var
else:
num_samples = ceil(data_len / split_var)
return num_samples
def __iter__(self) -> Iterator[T_co]:
"""Handles assembling the batches from a bucketed perspective
Shuffle bucket order->Pad if necessary->Slice across replicas->Possibly batch up residuals->shuffle bucketed
batches->Unroll into list->Make iter
Returns
-------
Iterator[T_co]
"""
# Shuffle the bucketed idx
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Permute each bucket
indices = [
[val[idx] for idx in torch.randperm(len(val), generator=g).tolist()]
for val in self.bucket_idx
]
else:
indices = self.bucket_idx
# Iterate over the buckets
for idx, val in enumerate(indices):
# If this is true we need to handle padding
if (self.num_slices_per_bucket * self.slice_size) > len(val):
split_val = self._handle_padding(val)
indices[idx] = list(itertools.chain(*split_val))
assert len(indices[idx]) == self.rounded_num_samples_per_bucket
# Now slice across replicas
final_indices = []
for val in indices:
for idx in range(self.num_slices_per_bucket):
replica_slice = val[
(idx * self.slice_size) : ((idx + 1) * self.slice_size)
][self.rank : self.slice_size : self.num_replicas]
final_indices.append(replica_slice)
# If bucket overlap is allowed then we just batch up the residual indices
if self.drop_last and self.allow_bucket_overlap:
residual_idx = list(
itertools.chain(
*[val[self.rounded_num_samples_per_bucket :] for val in indices]
)
)
if len(residual_idx) > self.slice_size:
# Cut by slices then by replicas
residual_idx = [
residual_idx[
(idx * self.slice_size) : ((idx + 1) * self.slice_size)
][self.rank : self.slice_size : self.num_replicas]
for idx in range(len(residual_idx) // self.slice_size)
]
# Append to the final indices
final_indices.extend(residual_idx)
# Shuffle the bucketed batches
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# Permute the bucket order
final_indices = [
final_indices[val]
for val in torch.randperm(len(final_indices), generator=g)
]
# Unroll into a single list
final_indices = list(itertools.chain(*final_indices))
assert len(final_indices) == self.rounded_num_samples_per_replica
return iter(final_indices)
def _handle_padding(self, idx_list: List):
"""Handles padding out if a batch is short
Parameters
----------
idx_list: List
list of indices
Returns
-------
split_val: List
list with correctly padded sizes
"""
split_val = []
for idx in range(self.num_slices_per_bucket):
if idx == (self.num_slices_per_bucket - 1):
# Get the short batch
short_batch = idx_list[(idx * self.slice_size) :]
# Short batch replica slice sizes
short_len = [
self.batch_size - len(list(val))
for val in np.array_split(short_batch, self.num_replicas)
]
# Pop the necessary values from the entire bucket
pad_values = [
idx_list[s_idx : (self.num_replicas * s_len) : self.num_replicas]
for s_idx, s_len in enumerate(short_len)
]
# If not a consistent list then we need to reorder so that the step size alignment slicing
# of the replicas works
if len(set(short_len)) != 1:
# here we need to find the first larger idx and reorder
first_idx = short_len.index(max(set(short_len)))
# Reorder
pad_values = pad_values[first_idx:] + pad_values[0:first_idx]
extended_batch = short_batch + [
pad
for pad in list(
itertools.chain(*itertools.zip_longest(*pad_values))
)
if pad is not None
]
split_val.append(extended_batch)
else:
split_val.append(
idx_list[(idx * self.slice_size) : ((idx + 1) * self.slice_size)]
)
return split_val
def __len__(self) -> int:
return self.rounded_num_samples_per_replica
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Parameters
----------
epoch: int
Epoch number
"""
self.epoch = epoch
Ancestors (in MRO)
- torch.utils.data.sampler.Sampler
- typing.Generic
Methods
set_epoch
def set_epoch(
self,
epoch: int
) -> None
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Epoch number | None |
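As with PyTorch's `DistributedSampler`, `set_epoch` should be called at the start of every epoch when `shuffle=True`; otherwise each epoch reuses the same ordering. A brief sketch, assuming `sampler` is the `BucketedDistributedSampler` from the earlier construction example and `loader` is a DataLoader driven by it:

```python
num_epochs = 10  # assumed training length

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seeds the per-epoch shuffle (seed + epoch)
    for batch in loader:
        ...  # training step
```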
??? example "View Source" def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Parameters
----------
epoch: int
Epoch number
"""
self.epoch = epoch
StokeDataLoader
class StokeDataLoader(
gpu: bool,
fp16: Union[stoke.status.FP16Options, NoneType],
dataset: torch.utils.data.dataset.Dataset[+T_co],
batch_size: Union[int, NoneType] = 1,
shuffle: bool = False,
sampler: Union[torch.utils.data.sampler.Sampler[int], NoneType] = None,
batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence[int]], NoneType] = None,
num_workers: int = 0,
collate_fn: Callable[[List[~T]], Any] = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: Callable[[int], NoneType] = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_gpu | bool | flag indicating whether yielded batches should be moved onto the GPU | None |
_fp16 | Optional[FP16Options] | active FP16 backend (if any), used to decide whether batches are cast to half | None |
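In typical usage this class is not constructed by hand; a configured `Stoke` object builds it so that `_gpu` and `_fp16` are filled in from the underlying `StokeStatus`. A hedged sketch (the `stoke_obj.DataLoader(...)` call belongs to the wider library, not this module, and `my_dataset` is hypothetical):

```python
# stoke_obj is assumed to be an already-configured stoke.Stoke instance
# (model, optimizer, loss, and backend flags supplied elsewhere)
loader = stoke_obj.DataLoader(
    dataset=my_dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
)

for batch in loader:
    ...  # batches arrive on the device dictated by the Stoke configuration
```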
??? example "View Source" class StokeDataLoader(DL):
"""Provides a shim interface to torch.utils.data.DataLoader with mapped kwargs
Attributes
----------
_gpu: bool
_fp16: Optional[FP16Options]
See Also
--------
torch.utils.data.DataLoader: base DataLoader class that this inherits from (check for all attributes)
"""
def __init__(
self,
gpu: bool,
fp16: Optional[FP16Options],
dataset: Dataset[T_co],
batch_size: Optional[int] = 1,
shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0,
collate_fn: _collate_fn_t = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
):
"""Maps to torch.utils.data.DataLoader __init__
The shim is necessary to automatically handle device placement, since the gpu/fp16 flags can't be
determined until the StokeStatus object is available (which is post init). This could be disconnected from
this class, but it would require the user to forward device or fp16 configs, which breaks the
paradigm that the flags only need to be set and never handled.
Parameters
----------
dataset: Dataset
dataset from which to load the data.
batch_size: int, default: 1
how many samples per batch to load.
shuffle: bool, default: False
set to ``True`` to have the data reshuffled at every epoch.
sampler: Sampler or Iterable, default: None
defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler: Sampler or Iterable, default: None
like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers: int, default: 0
how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
collate_fn: callable, optional
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory: bool, default: False
If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your
data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the memory pinning example in the PyTorch DataLoader documentation.
drop_last: bool, default: False
set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
will be smaller.
timeout: numeric, default: 0
if positive, the timeout value for collecting a batch from workers. Should always be non-negative.
worker_init_fn: callable, default: None
If not ``None``, this will be called on each worker subprocess with the worker id
(an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading.
prefetch_factor: int, default: 2
Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers
samples prefetched across all workers.
persistent_workers: bool, default: False
If ``True``, the data loader will not shut down the worker processes after a dataset has been
consumed once. This allows the workers' `Dataset` instances to remain alive.
Returns
-------
StokeDataLoader
wrapped torch.utils.data.DataLoader object
"""
# Call super init for the actual torch DataLoader
super(StokeDataLoader, self).__init__(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context,
generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
self._gpu = gpu
self._fp16 = fp16
def __iter__(self):
"""Underlying iter of the DataLoader that yields samples
Wrap the base __iter__ with a call to place on the device if flagged
Yields
------
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
data placed on the correct device
"""
# Iterate using the base class iter but override the yield by pushing to device prior if gpu flag is true
for val in super().__iter__():
yield val if not self._gpu else self._place_data_on_gpu(val)
def _place_data_on_gpu(
self,
data: Union[
torch.Tensor,
List[torch.Tensor],
Tuple[torch.Tensor],
Dict[str, torch.Tensor],
],
):
"""Determine data structure and then place on the correct device (cast in the context of deepspeed FP16 as it
wants half dtype as input)
Parameters
----------
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
current data coming from the underlying __iter__
Returns
-------
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, torch.Tensor]]
data moved to the correct device
"""
if isinstance(data, torch.Tensor):
# TODO: Check if one of the APEX version needs a cast too?
# Move to the correct cuda device w/ the correct type -- deepspeed FP16 requires a cast to half if fp16
if self._fp16 == "deepspeed":
return data.to(device="cuda", dtype=torch.half)
else:
return data.to(device="cuda", dtype=data.dtype)
elif isinstance(data, (list, tuple)):
return type(data)(self._place_data_on_gpu(data=val) for val in data)
elif isinstance(data, dict):
return {k: self._place_data_on_gpu(v) for k, v in data.items()}
elif not hasattr(data, "to"):
return data
else:
raise TypeError(
f"Stoke -- Unsupported data type passed to _place_data_on_gpu "
f"(torch.Tensor, tuple, list, dict), currently {type(data)}"
)
Ancestors (in MRO)
- torch.utils.data.dataloader.DataLoader
- typing.Generic
Instance variables
multiprocessing_context
Methods
check_worker_number_rationality
def check_worker_number_rationality(
self
)
??? example "View Source" def check_worker_number_rationality(self):
# This function checks whether the dataloader's worker number is rational based on
# the current system's resources. The current rule is that if the number of workers this
# Dataloader will create is bigger than the number of logical cpus that it is allowed to
# use, then we will pop up a warning to let the user pay attention.
#
# eg. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
# threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say the current
# DataLoader process can use half of them, which is 32, then the rational max number of
# workers that can be initiated from this process is 32.
# Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
# So the warning message is triggered to notify the user to lower the worker number if
# necessary.
#
#
# [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
# available (available in most Linux systems, but not OSX and Windows).
# When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
# it doesn't respect cpuset.
# We don't take threading into account since each worker process is single threaded
# at this time.
#
# We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
# other than `torch.set_num_threads` to 1 in the worker process, if the passing
# in functions use 3rd party modules that rely on those threading flags to determine
# how many thread to create (eg. numpy, etc), then it is caller's responsibility to
# set those flags correctly.
def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
suggested_max_worker_msg = ((
"Our suggested max number of worker in current system is {}{}, which is smaller "
"than what this DataLoader is going to create.").format(
num_worker_suggest,
("" if cpuset_checked else " (`cpuset` is not taken into account)"))
) if num_worker_suggest is not None else (
"DataLoader is not able to compute a suggested max number of worker in current system.")
warn_msg = (
"This DataLoader will create {} worker processes in total. {} "
"Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
"lower the worker number to avoid potential slowness/freeze if necessary.").format(
num_worker_created,
suggested_max_worker_msg)
return warn_msg
if not self.num_workers or self.num_workers == 0:
return
# try to compute a suggested max number of worker based on system's resource
max_num_worker_suggest = None
cpuset_checked = False
if hasattr(os, 'sched_getaffinity'):
try:
max_num_worker_suggest = len(os.sched_getaffinity(0))
cpuset_checked = True
except Exception:
pass
if max_num_worker_suggest is None:
# os.cpu_count() could return Optional[int]
# get cpu count first and check None in order to satisfy mypy check
cpu_count = os.cpu_count()
if cpu_count is not None:
max_num_worker_suggest = cpu_count
if max_num_worker_suggest is None:
warnings.warn(_create_warning_msg(
max_num_worker_suggest,
self.num_workers,
cpuset_checked))
return
if self.num_workers > max_num_worker_suggest:
warnings.warn(_create_warning_msg(
max_num_worker_suggest,
self.num_workers,
cpuset_checked))