Module stoke.extensions
Handles extension wrapper-related classes -- mixin style
??? example "View Source" # -*- coding: utf-8 -*-
# Copyright FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0
"""Handles extension wrapper related classes -- mixin style"""
from abc import ABC
from enum import Enum
from typing import Dict, Optional, Tuple, Type, Union
import attr
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel, ShardedDataParallel
from fairscale.optim.oss import OSS
from stoke.configs import (
DDPConfig,
FairscaleFSDPConfig,
FairscaleOSSConfig,
FairscaleSDDPConfig,
)
@attr.s(auto_attribs=True)
class _FairscaleFSDPConfig(FairscaleFSDPConfig):
mixed_precision: bool = False
class BaseOptimizer(ABC):
"""Base class for creating an optimizer
Attributes
----------
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, verbose: bool = True, **kwargs):
"""Init for BaseOptimizer class
Parameters
----------
verbose: bool, default: True
flag for verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> torch.optim.Optimizer:
"""Instantiates a torch optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
torch.optim.Optimizer
instantiated torch optimizer object
"""
if self._verbose:
self._print_device(f"Creating basic torch optimizer: {optimizer.__name__}")
return optimizer(params=model.parameters(), **optimizer_kwargs)
class FairscaleOSSExtension(BaseOptimizer):
"""Inherits from BaseOptimizer for OSS class creation
Attributes
----------
_oss_config: FairscaleOSSConfig
Configuration object for Fairscale OSS
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, oss_config: FairscaleOSSConfig, verbose: bool = True, **kwargs):
"""Init for FairscaleOSSExtension class
Parameters
----------
oss_config: FairscaleOSSConfig
Configuration object for Fairscale OSS
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
super(FairscaleOSSExtension, self).__init__(verbose=verbose)
self._oss_config = oss_config
def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> OSS:
"""Instantiates a Fairscale OSS optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
OSS
instantiated Fairscale OSS optimizer object
"""
if self._verbose:
self._print_device(
f"Creating Fairscale OSS wrapped PyTorch optimizer: {optimizer.__name__}"
)
return OSS(
params=model.parameters(),
optim=optimizer,
broadcast_fp16=self._oss_config.broadcast_fp16,
**optimizer_kwargs,
)
class RunnerOptimizerEnum(Enum):
"""Enum for optimizer creation"""
oss = FairscaleOSSExtension
base = BaseOptimizer
class BaseDDP:
"""Base class for using the DDP backend
Attributes
----------
_ddp_config: DDPConfig
Base DDP configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, ddp_config: DDPConfig, verbose: bool = True, **kwargs):
"""Init for BaseDDP
Parameters
----------
ddp_config: DDPConfig
Base DDP configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._ddp_config = ddp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the base DDP call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = torch.nn.parallel.DistributedDataParallel(
module=model,
device_ids=[rank],
output_device=rank,
bucket_cap_mb=self._ddp_config.bucket_cap_mb,
broadcast_buffers=self._ddp_config.broadcast_buffers,
find_unused_parameters=self._ddp_config.find_unused_parameters,
gradient_as_bucket_view=self._ddp_config.gradient_as_bucket_view,
)
return model, optimizer
class FairscaleSDDPExtension:
"""Class for using the Fairscale SDDP backend
Attributes
----------
_sddp_config: FairscaleSDDPConfig
Base Fairscale ShardedDataParallel configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(
self, sddp_config: FairscaleSDDPConfig, verbose: bool = True, **kwargs
):
"""Init for FairscaleSDDPExtension
Parameters
----------
sddp_config: FairscaleSDDPConfig
Base Fairscale ShardedDataParallel configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._sddp_config = sddp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the ShardedDataParallel call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = ShardedDataParallel(
module=model,
sharded_optimizer=optimizer,
broadcast_buffers=self._sddp_config.broadcast_buffers,
sync_models_at_startup=self._sddp_config.sync_models_at_startup,
reduce_buffer_size=self._sddp_config.reduce_buffer_size,
auto_refresh_trainable=self._sddp_config.auto_refresh_trainable,
reduce_fp16=self._sddp_config.reduce_fp16,
)
return model, optimizer
class FairscaleFSDPExtension:
"""Class for using the Fairscale FSDP backend
Attributes
----------
_fsdp_config: _FairscaleFSDPConfig
Base Fairscale Fully Sharded Data Parallel configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(
self, fsdp_config: _FairscaleFSDPConfig, verbose: bool = True, **kwargs
):
"""Init for FairscaleSDDPExtension
Parameters
----------
fsdp_config: _FairscaleFSDPConfig
Base Fairscale Fully Sharded Data Parallel configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._fsdp_config = fsdp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the FullyShardedDataParallel call
Also sets grad divide factors
https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.set_gradient_divide_factors
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = FullyShardedDataParallel(
module=model,
reshard_after_forward=self._fsdp_config.reshard_after_forward,
mixed_precision=self._fsdp_config.mixed_precision,
fp32_reduce_scatter=self._fsdp_config.fp32_reduce_scatter,
flatten_parameters=self._fsdp_config.flatten_parameters,
move_params_to_cpu=self._fsdp_config.move_params_to_cpu,
compute_dtype=self._fsdp_config.compute_dtype,
buffer_dtype=self._fsdp_config.buffer_dtype,
move_grads_to_cpu=self._fsdp_config.move_grads_to_cpu,
bucket_cap_mb=self._fsdp_config.bucket_cap_mb,
no_broadcast_optim_state=self._fsdp_config.no_broadcast_optim_state,
clear_autocast_cache=self._fsdp_config.clear_autocast_cache,
force_input_to_fp32=self._fsdp_config.force_input_to_fp32,
verbose=self._fsdp_config.verbose,
)
# Trigger the set of pre-divide or post-divide factors if set in the config
model.set_gradient_divide_factors(
pre=self._fsdp_config.gradient_predivide_factor
if self._fsdp_config.gradient_predivide_factor is not None
else model.gradient_predivide_factor,
post=self._fsdp_config.gradient_postdivide_factor
if self._fsdp_config.gradient_postdivide_factor is not None
else model.gradient_postdivide_factor,
recursive=True,
)
return model, optimizer
class DistributedHandlerEnum(Enum):
"""Enum for DDP use"""
sddp = FairscaleSDDPExtension
fsdp = FairscaleFSDPExtension
base = BaseDDP
Classes
BaseDDP
class BaseDDP(
ddp_config: stoke.configs.DDPConfig,
verbose: bool = True,
**kwargs
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_ddp_config | DDPConfig | Base DDP configuration object | None |
_verbose | bool, default: True | flag for Stoke print verbosity | None |
??? example "View Source" class BaseDDP:
"""Base class for using the DDP backend
Attributes
----------
_ddp_config: DDPConfig
Base DDP configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, ddp_config: DDPConfig, verbose: bool = True, **kwargs):
"""Init for BaseDDP
Parameters
----------
ddp_config: DDPConfig
Base DDP configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._ddp_config = ddp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the base DDP call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = torch.nn.parallel.DistributedDataParallel(
module=model,
device_ids=[rank],
output_device=rank,
bucket_cap_mb=self._ddp_config.bucket_cap_mb,
broadcast_buffers=self._ddp_config.broadcast_buffers,
find_unused_parameters=self._ddp_config.find_unused_parameters,
gradient_as_bucket_view=self._ddp_config.gradient_as_bucket_view,
)
return model, optimizer
Methods
handle_ddp
def handle_ddp(
self,
model: torch.nn.modules.module.Module,
optimizer: Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS],
grad_accum: Union[int, NoneType],
rank: int
) -> Tuple[torch.nn.modules.module.Module, Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS]]
Wraps the model in the base DDP call
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | torch.nn.Module | Current model object | None |
optimizer | Union[torch.optim.Optimizer, OSS] | Current optimizer object | None |
grad_accum | int, default: None | Number of gradient accumulation steps | None |
rank | int | Current CUDA device rank in the distributed setup | None |
Returns:
Type | Description |
---|---|
Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]] | Wrapped model object and the (unchanged) optimizer object |
??? example "View Source" def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the base DDP call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = torch.nn.parallel.DistributedDataParallel(
module=model,
device_ids=[rank],
output_device=rank,
bucket_cap_mb=self._ddp_config.bucket_cap_mb,
broadcast_buffers=self._ddp_config.broadcast_buffers,
find_unused_parameters=self._ddp_config.find_unused_parameters,
gradient_as_bucket_view=self._ddp_config.gradient_as_bucket_view,
)
return model, optimizer
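A minimal usage sketch (not part of the generated source): wrapping a model with `BaseDDP` assumes `torch.distributed` has already been initialized for the current process; the `DDPConfig` construction below is illustrative only -- see `stoke.configs` for its actual fields.

```python
import torch
import torch.distributed as dist

from stoke.configs import DDPConfig
from stoke.extensions import BaseDDP

# Assumes the usual env:// rendezvous variables (MASTER_ADDR, RANK, etc.) are set
dist.init_process_group(backend="nccl")
rank = dist.get_rank()

model = torch.nn.Linear(8, 8).to(rank)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# DDPConfig construction is illustrative; required fields may differ
handler = BaseDDP(ddp_config=DDPConfig(), verbose=True)
model, optimizer = handler.handle_ddp(
    model=model, optimizer=optimizer, grad_accum=None, rank=rank
)
```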
BaseOptimizer
class BaseOptimizer(
verbose: bool = True,
**kwargs
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_verbose | bool, default: True | flag for Stoke print verbosity | None |
??? example "View Source" class BaseOptimizer(ABC):
"""Base class for creating an optimizer
Attributes
----------
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, verbose: bool = True, **kwargs):
"""Init for BaseOptimizer class
Parameters
----------
verbose: bool, default: True
flag for verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> torch.optim.Optimizer:
"""Instantiates a torch optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
torch.optim.Optimizer
instantiated torch optimizer object
"""
if self._verbose:
self._print_device(f"Creating basic torch optimizer: {optimizer.__name__}")
return optimizer(params=model.parameters(), **optimizer_kwargs)
Ancestors (in MRO)
- abc.ABC
Descendants
- stoke.extensions.FairscaleOSSExtension
Methods
build_optimizer
def build_optimizer(
self,
optimizer: Type[torch.optim.optimizer.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.modules.module.Module
) -> torch.optim.optimizer.Optimizer
Instantiates a torch optimizer object from the type and optimizer kwargs
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizer | Type[torch.optim.Optimizer] | type of torch optimizer | None |
optimizer_kwargs | Dict | dictionary of all kwargs to pass to the optimizer | None |
model | torch.nn.Module | model object | None |
Returns:
Type | Description |
---|---|
torch.optim.Optimizer | instantiated torch optimizer object |
??? example "View Source" def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> torch.optim.Optimizer:
"""Instantiates a torch optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
torch.optim.Optimizer
instantiated torch optimizer object
"""
if self._verbose:
self._print_device(f"Creating basic torch optimizer: {optimizer.__name__}")
return optimizer(params=model.parameters(), **optimizer_kwargs)
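A hedged sketch of `build_optimizer` in isolation. `BaseOptimizer` calls `self._print_device` when verbose, a hook it does not define itself (in Stoke it is supplied by a cooperating mixin), so the small subclass below is a hypothetical stand-in for that plumbing.

```python
import torch

from stoke.extensions import BaseOptimizer


class _Printable(BaseOptimizer):
    # Hypothetical stand-in: supplies the _print_device hook that
    # build_optimizer uses when verbose=True
    def _print_device(self, msg: str) -> None:
        print(msg)


model = torch.nn.Linear(4, 2)
opt = _Printable(verbose=True).build_optimizer(
    optimizer=torch.optim.SGD,
    optimizer_kwargs={"lr": 0.1, "momentum": 0.9},
    model=model,
)
assert isinstance(opt, torch.optim.SGD)
```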
DistributedHandlerEnum
class DistributedHandlerEnum(
*args,
**kwargs
)
??? example "View Source" class DistributedHandlerEnum(Enum):
"""Enum for DDP use"""
sddp = FairscaleSDDPExtension
fsdp = FairscaleFSDPExtension
base = BaseDDP
Ancestors (in MRO)
- enum.Enum
Class variables
base
fsdp
name
sddp
value
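Because each member's `value` is the handler class itself, a backend can be selected by name; a small illustrative dispatch (config construction elided):

```python
from stoke.extensions import DistributedHandlerEnum

handler_cls = DistributedHandlerEnum["base"].value  # -> BaseDDP
# handler = handler_cls(ddp_config=..., verbose=True)  # kwargs depend on the backend
```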
FairscaleFSDPExtension
class FairscaleFSDPExtension(
fsdp_config: stoke.extensions._FairscaleFSDPConfig,
verbose: bool = True,
**kwargs
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_fsdp_config | _FairscaleFSDPConfig | Base Fairscale Fully Sharded Data Parallel configuration object | None |
_verbose | bool, default: True | flag for Stoke print verbosity | None |
??? example "View Source" class FairscaleFSDPExtension:
"""Class for using the Fairscale FSDP backend
Attributes
----------
_fsdp_config: _FairscaleFSDPConfig
Base Fairscale Fully Sharded Data Parallel configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(
self, fsdp_config: _FairscaleFSDPConfig, verbose: bool = True, **kwargs
):
"""Init for FairscaleSDDPExtension
Parameters
----------
fsdp_config: _FairscaleFSDPConfig
Base Fairscale Fully Sharded Data Parallel configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._fsdp_config = fsdp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the FullyShardedDataParallel call
Also sets grad divide factors
https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.set_gradient_divide_factors
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = FullyShardedDataParallel(
module=model,
reshard_after_forward=self._fsdp_config.reshard_after_forward,
mixed_precision=self._fsdp_config.mixed_precision,
fp32_reduce_scatter=self._fsdp_config.fp32_reduce_scatter,
flatten_parameters=self._fsdp_config.flatten_parameters,
move_params_to_cpu=self._fsdp_config.move_params_to_cpu,
compute_dtype=self._fsdp_config.compute_dtype,
buffer_dtype=self._fsdp_config.buffer_dtype,
move_grads_to_cpu=self._fsdp_config.move_grads_to_cpu,
bucket_cap_mb=self._fsdp_config.bucket_cap_mb,
no_broadcast_optim_state=self._fsdp_config.no_broadcast_optim_state,
clear_autocast_cache=self._fsdp_config.clear_autocast_cache,
force_input_to_fp32=self._fsdp_config.force_input_to_fp32,
verbose=self._fsdp_config.verbose,
)
# Trigger the set of pre-divide or post-divide factors if set in the config
model.set_gradient_divide_factors(
pre=self._fsdp_config.gradient_predivide_factor
if self._fsdp_config.gradient_predivide_factor is not None
else model.gradient_predivide_factor,
post=self._fsdp_config.gradient_postdivide_factor
if self._fsdp_config.gradient_postdivide_factor is not None
else model.gradient_postdivide_factor,
recursive=True,
)
return model, optimizer
Methods
handle_ddp
def handle_ddp(
self,
model: torch.nn.modules.module.Module,
optimizer: Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS],
grad_accum: Union[int, NoneType],
rank: int
) -> Tuple[torch.nn.modules.module.Module, Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS]]
Wraps the model in the FullyShardedDataParallel call
Also sets grad divide factors https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.set_gradient_divide_factors
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | torch.nn.Module | Current model object | None |
optimizer | Union[torch.optim.Optimizer, OSS] | Current optimizer object | None |
grad_accum | int, default: None | Number of gradient accumulation steps | None |
rank | int | Current CUDA device rank in the distributed setup | None |
Returns:
Type | Description |
---|---|
Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]] | Wrapped model object and the (unchanged) optimizer object |
??? example "View Source" def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the FullyShardedDataParallel call
Also sets grad divide factors
https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.set_gradient_divide_factors
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = FullyShardedDataParallel(
module=model,
reshard_after_forward=self._fsdp_config.reshard_after_forward,
mixed_precision=self._fsdp_config.mixed_precision,
fp32_reduce_scatter=self._fsdp_config.fp32_reduce_scatter,
flatten_parameters=self._fsdp_config.flatten_parameters,
move_params_to_cpu=self._fsdp_config.move_params_to_cpu,
compute_dtype=self._fsdp_config.compute_dtype,
buffer_dtype=self._fsdp_config.buffer_dtype,
move_grads_to_cpu=self._fsdp_config.move_grads_to_cpu,
bucket_cap_mb=self._fsdp_config.bucket_cap_mb,
no_broadcast_optim_state=self._fsdp_config.no_broadcast_optim_state,
clear_autocast_cache=self._fsdp_config.clear_autocast_cache,
force_input_to_fp32=self._fsdp_config.force_input_to_fp32,
verbose=self._fsdp_config.verbose,
)
# Trigger the set of pre-divide or post-divide factors if set in the config
model.set_gradient_divide_factors(
pre=self._fsdp_config.gradient_predivide_factor
if self._fsdp_config.gradient_predivide_factor is not None
else model.gradient_predivide_factor,
post=self._fsdp_config.gradient_postdivide_factor
if self._fsdp_config.gradient_postdivide_factor is not None
else model.gradient_postdivide_factor,
recursive=True,
)
return model, optimizer
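An illustrative sketch only: `_FairscaleFSDPConfig` is module-private and is normally built internally by Stoke, so constructing it directly here is purely for demonstration and assumes its attrs fields all carry defaults. A `torch.distributed` process group must already be initialized.

```python
import torch
import torch.distributed as dist

from stoke.extensions import FairscaleFSDPExtension, _FairscaleFSDPConfig

dist.init_process_group(backend="nccl")  # assumes env:// rendezvous is configured
rank = dist.get_rank()

model = torch.nn.Linear(16, 16).to(rank)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Direct construction of the private config is for illustration only
ext = FairscaleFSDPExtension(fsdp_config=_FairscaleFSDPConfig())
model, optimizer = ext.handle_ddp(
    model=model, optimizer=optimizer, grad_accum=None, rank=rank
)
```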
FairscaleOSSExtension
class FairscaleOSSExtension(
oss_config: stoke.configs.FairscaleOSSConfig,
verbose: bool = True,
**kwargs
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_oss_config | FairscaleOSSConfig | Configuration object for Fairscale OSS | None |
_verbose | bool, default: True | flag for Stoke print verbosity | None |
??? example "View Source" class FairscaleOSSExtension(BaseOptimizer):
"""Inherits from BaseOptimizer for OSS class creation
Attributes
----------
_oss_config: FairscaleOSSConfig
Configuration object for Fairscale OSS
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(self, oss_config: FairscaleOSSConfig, verbose: bool = True, **kwargs):
"""Init for FairscaleOSSExtension class
Parameters
----------
oss_config: FairscaleOSSConfig
Configuration object for Fairscale OSS
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
super(FairscaleOSSExtension, self).__init__(verbose=verbose)
self._oss_config = oss_config
def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> OSS:
"""Instantiates a Fairscale OSS optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
OSS
instantiated Fairscale OSS optimizer object
"""
if self._verbose:
self._print_device(
f"Creating Fairscale OSS wrapped PyTorch optimizer: {optimizer.__name__}"
)
return OSS(
params=model.parameters(),
optim=optimizer,
broadcast_fp16=self._oss_config.broadcast_fp16,
**optimizer_kwargs,
)
Ancestors (in MRO)
- stoke.extensions.BaseOptimizer
- abc.ABC
Methods
build_optimizer
def build_optimizer(
self,
optimizer: Type[torch.optim.optimizer.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.modules.module.Module
) -> fairscale.optim.oss.OSS
Instantiates a Fairscale OSS optimizer object from the type and optimizer kwargs
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizer | Type[torch.optim.Optimizer] | type of torch optimizer | None |
optimizer_kwargs | Dict | dictionary of all kwargs to pass to the optimizer | None |
model | torch.nn.Module | model object | None |
Returns:
Type | Description |
---|---|
OSS | instantiated Fairscale OSS optimizer object |
??? example "View Source" def build_optimizer(
self,
optimizer: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict,
model: torch.nn.Module,
) -> OSS:
"""Instantiates a Fairscale OSS optimizer object from the type and optimizer kwargs
Parameters
----------
optimizer: Type[torch.optim.Optimizer]
type of torch optimizer
optimizer_kwargs: Dict
dictionary of all kwargs to pass to the optimizer
model: torch.nn.Module
model object
Returns
-------
OSS
instantiated Fairscale OSS optimizer object
"""
if self._verbose:
self._print_device(
f"Creating Fairscale OSS wrapped PyTorch optimizer: {optimizer.__name__}"
)
return OSS(
params=model.parameters(),
optim=optimizer,
broadcast_fp16=self._oss_config.broadcast_fp16,
**optimizer_kwargs,
)
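A hedged sketch of sharding optimizer state with `FairscaleOSSExtension`. Fairscale's `OSS` needs an initialized `torch.distributed` process group, and (as with `BaseOptimizer`) verbose printing goes through `_print_device`, which a cooperating Stoke mixin normally provides; `verbose=False` sidesteps that here. `FairscaleOSSConfig()` is assumed to be constructible from defaults.

```python
import torch
import torch.distributed as dist

from stoke.configs import FairscaleOSSConfig
from stoke.extensions import FairscaleOSSExtension

dist.init_process_group(backend="nccl")  # OSS partitions state across ranks

model = torch.nn.Linear(8, 8)
ext = FairscaleOSSExtension(oss_config=FairscaleOSSConfig(), verbose=False)
oss_opt = ext.build_optimizer(
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"lr": 1e-3},
    model=model,
)
```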
FairscaleSDDPExtension
class FairscaleSDDPExtension(
sddp_config: stoke.configs.FairscaleSDDPConfig,
verbose: bool = True,
**kwargs
)
Attributes
Name | Type | Description | Default |
---|---|---|---|
_sddp_config | FairscaleSDDPConfig | Base Fairscale ShardedDataParallel configuration object | None |
_verbose | bool, default: True | flag for Stoke print verbosity | None |
??? example "View Source" class FairscaleSDDPExtension:
"""Class for using the Fairscale SDDP backend
Attributes
----------
_sddp_config: FairscaleSDDPConfig
Base Fairscale ShardedDataParallel configuration object
_verbose: bool, default: True
flag for Stoke print verbosity
"""
def __init__(
self, sddp_config: FairscaleSDDPConfig, verbose: bool = True, **kwargs
):
"""Init for FairscaleSDDPExtension
Parameters
----------
sddp_config: FairscaleSDDPConfig
Base Fairscale ShardedDataParallel configuration object
verbose: bool, default: True
flag for Stoke print verbosity
**kwargs: dict, optional
Extra arguments passed to the __init__ call
"""
self._verbose = verbose
self._sddp_config = sddp_config
def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the ShardedDataParallel call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = ShardedDataParallel(
module=model,
sharded_optimizer=optimizer,
broadcast_buffers=self._sddp_config.broadcast_buffers,
sync_models_at_startup=self._sddp_config.sync_models_at_startup,
reduce_buffer_size=self._sddp_config.reduce_buffer_size,
auto_refresh_trainable=self._sddp_config.auto_refresh_trainable,
reduce_fp16=self._sddp_config.reduce_fp16,
)
return model, optimizer
Methods
handle_ddp
def handle_ddp(
self,
model: torch.nn.modules.module.Module,
optimizer: Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS],
grad_accum: Union[int, NoneType],
rank: int
) -> Tuple[torch.nn.modules.module.Module, Union[torch.optim.optimizer.Optimizer, fairscale.optim.oss.OSS]]
Wraps the model in the ShardedDataParallel call
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | torch.nn.Module | Current model object | None |
optimizer | Union[torch.optim.Optimizer, OSS] | Current optimizer object | None |
grad_accum | int, default: None | Number of gradient accumulation steps | None |
rank | int | Current CUDA device rank in the distributed setup | None |
Returns:
Type | Description |
---|---|
Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]] | Wrapped model object and the (unchanged) optimizer object |
??? example "View Source" def handle_ddp(
self,
model: torch.nn.Module,
optimizer: Union[torch.optim.Optimizer, OSS],
grad_accum: Optional[int],
rank: int,
) -> Tuple[torch.nn.Module, Union[torch.optim.Optimizer, OSS]]:
"""Wraps the model in the ShardedDataParallel call
Parameters
----------
model: torch.nn.Module
Current model object
optimizer: Union[torch.optim.Optimizer, OSS]
Current optimizer object
grad_accum: int, default: None
Number of gradient accumulation steps
rank: int
Current CUDA device rank in the distributed setup
Returns
-------
model: torch.nn.Module
Wrapped model object
optimizer: Union[torch.optim.Optimizer, OSS]
current optimizer object
"""
model = ShardedDataParallel(
module=model,
sharded_optimizer=optimizer,
broadcast_buffers=self._sddp_config.broadcast_buffers,
sync_models_at_startup=self._sddp_config.sync_models_at_startup,
reduce_buffer_size=self._sddp_config.reduce_buffer_size,
auto_refresh_trainable=self._sddp_config.auto_refresh_trainable,
reduce_fp16=self._sddp_config.reduce_fp16,
)
return model, optimizer
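Illustrative only: `ShardedDataParallel` expects a Fairscale `OSS` optimizer, so in Stoke this extension is paired with `FairscaleOSSExtension`. The sketch below assumes default-constructible configs and an already-initialized process group.

```python
import torch
import torch.distributed as dist

from stoke.configs import FairscaleOSSConfig, FairscaleSDDPConfig
from stoke.extensions import FairscaleOSSExtension, FairscaleSDDPExtension

dist.init_process_group(backend="nccl")
rank = dist.get_rank()

model = torch.nn.Linear(8, 8).to(rank)
# Build the sharded (OSS) optimizer first, then hand it to SDDP
oss_opt = FairscaleOSSExtension(
    oss_config=FairscaleOSSConfig(), verbose=False
).build_optimizer(torch.optim.SGD, {"lr": 0.01}, model)

ext = FairscaleSDDPExtension(sddp_config=FairscaleSDDPConfig())
model, oss_opt = ext.handle_ddp(
    model=model, optimizer=oss_opt, grad_accum=None, rank=rank
)
```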
RunnerOptimizerEnum
class RunnerOptimizerEnum(
*args,
**kwargs
)
??? example "View Source" class RunnerOptimizerEnum(Enum):
"""Enum for optimizer creation"""
oss = FairscaleOSSExtension
base = BaseOptimizer
Ancestors (in MRO)
- enum.Enum
Class variables
base
name
oss
value
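As with `DistributedHandlerEnum`, each member's `value` is the builder class itself, so an optimizer strategy can be picked by name; a small illustrative dispatch:

```python
from stoke.extensions import RunnerOptimizerEnum

builder_cls = RunnerOptimizerEnum["oss"].value  # -> FairscaleOSSExtension
# builder = builder_cls(oss_config=..., verbose=True)  # kwargs depend on the member
```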