Stoke

Add a little accelerant to your torch

About

stoke is a lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices (e.g., CPU, GPU), distributed modes, mixed precision, and PyTorch extensions. This allows you to switch from local full-precision CPU training to mixed-precision distributed multi-GPU training with extensions (like optimizer state sharding) by changing a few declarative flags. Additionally, stoke exposes configuration settings for every underlying backend for those who want fine-grained control and raw access to the underlying libraries.
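
To make the declarative flags concrete, below is a minimal sketch of constructing a Stoke object. The Stoke, StokeOptimizer, DistributedOptions, and FP16Options names come from the project documentation, but treat the exact flags and values here as illustrative rather than authoritative.

```python
import torch

from stoke import DistributedOptions, FP16Options, Stoke, StokeOptimizer

# Plain PyTorch model and loss -- stoke does not change how you define these
model = torch.nn.Linear(128, 10)
loss = torch.nn.CrossEntropyLoss()

# The optimizer is declared (class + kwargs) so stoke can construct it
# correctly for whatever backend ends up running it
opt = StokeOptimizer(optimizer=torch.optim.Adam, optimizer_kwargs={"lr": 1e-3})

# Declarative flags: flipping these moves the same code between CPU/GPU,
# DDP, and mixed precision without touching the training loop
stoke_obj = Stoke(
    model=model,
    optimizer=opt,
    loss=loss,
    batch_size_per_device=32,
    gpu=True,                            # device placement
    fp16=FP16Options.amp,                # mixed precision via native AMP
    distributed=DistributedOptions.ddp,  # distributed mode
)
```

Switching back to local full-precision CPU is then just a matter of flipping those flags (e.g., `gpu=False` and omitting the `fp16` and `distributed` arguments), with the rest of the code unchanged.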

In short, stoke is the best of PyTorch Lightning's Accelerators, disconnected from the rest of PyTorch Lightning. Write whatever PyTorch code you want, but leave device and backend context switching to stoke.

Supports

Devices (CPU and single- or multi-GPU), distributed training modes, mixed-precision training, and PyTorch extensions such as optimizer state sharding -- selected via the declarative flags shown in the sketch above.

Benefits/Capabilities

  • Declarative-style API -- declare the desired state and let stoke handle the rest
  • Mirrors base PyTorch style model, loss, backward, and step calls (see the training-loop sketch after this list)
  • Automatic device placement of model(s) and data
  • Universal interface for saving and loading regardless of backend(s) or device
  • Automatic handling of gradient accumulation and clipping
  • Common attrs interface for all backend configuration parameters (with docstrings)
  • Helper methods for printing synced losses, device-specific printing, and counting model parameters
  • Extra(s) -- a custom torch.utils.data.distributed.Sampler: BucketedDistributedSampler buckets data by a sorted index and then randomly samples from specific bucket(s), preventing grossly mismatched sequence lengths within a batch and the wasted computation (i.e., excess padding) they cause
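
As a sketch of the mirrored call style (reusing `stoke_obj` from the configuration example above), a training loop under stoke reads like plain PyTorch, with the model, loss, backward, and step calls routed through the Stoke object. The call pattern follows the project documentation; the toy dataset is purely illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data just to make the loop runnable -- any (inputs, targets) iterable works
dataset = TensorDataset(torch.randn(256, 128), torch.randint(0, 10, (256,)))
data_loader = DataLoader(dataset, batch_size=32)

for x, y in data_loader:
    # Forward pass through the wrapped model (stoke handles device placement)
    out = stoke_obj.model(x)
    # Loss computed through stoke so it can be scaled/synced as configured
    train_loss = stoke_obj.loss(out, y)
    # Backward with the correct backend semantics (AMP scaling, DDP sync, etc.)
    stoke_obj.backward(loss=train_loss)
    # Optimizer step, respecting any gradient accumulation and clipping settings
    stoke_obj.step()
```

Because accumulation, clipping, scaling, and synchronization live behind `backward` and `step`, the loop body stays identical as the declarative flags change.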