# Optuna Support

`spock` integrates with the Optuna hyper-parameter optimization framework through the provided
ask-and-tell interface and the define-and-run API. See docs. All examples can be found here.
## Defining the Backend

So let's continue in our Optuna-specific version of `tune.py`.

It's important to note that you can still use the `@spock` decorator to define any non-hyper-parameters! For
posterity, let's add some fixed parameters (those that are not part of hyper-parameter tuning) that we will use
elsewhere in our code.
```python
from spock import spock
from spock.addons.tune import (
    ChoiceHyperParameter,
    RangeHyperParameter,
    spockTuner,
)


@spock
class BasicParams:
    n_trials: int
    max_iter: int


@spockTuner
class LogisticRegressionHP:
    c: RangeHyperParameter
    solver: ChoiceHyperParameter
```
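The ranges and choices themselves live in a `spock` config file rather than in the class definitions. As a minimal
sketch of what that might look like (the field names `type`, `bounds`, `log_scale`, and `choices`, as well as all
values, are illustrative assumptions rather than verbatim from this example):

```yaml
# tune.yaml -- illustrative sketch, values are assumptions
BasicParams:
  n_trials: 10
  max_iter: 150

LogisticRegressionHP:
  c:
    type: float            # sample c as a float
    bounds: [1E-07, 10.0]  # low/high range handed to the sampler
    log_scale: true        # sample on a log scale
  solver:
    type: str
    choices: ["lbfgs", "saga"]
```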
Now we need to tell `spock` that we intend to do hyper-parameter tuning and which backend we would like to use. We
do this by calling the `tuner` method on the `SpockBuilder` object, passing in a configuration object for the
backend of choice (just like in basic functionality this is a chained command, thus the builder object will still be
returned). For Optuna one uses `OptunaTunerConfig`. This config mirrors all options that would be passed into
the `optuna.study.create_study` function call so that `spock` can set up the define-and-run API. (Note: The `@spockTuner`
decorated classes are passed to the `SpockBuilder` in the exact same way as basic `@spock` decorated classes.)
```python
from spock import SpockBuilder
from spock.addons.tune import OptunaTunerConfig

# Optuna config -- this will internally configure the study object for the define-and-run style which will be returned
# by accessing the tuner_status property on the SpockBuilder object
optuna_config = OptunaTunerConfig(
    study_name="Iris Logistic Regression", direction="maximize"
)

# Use the builder to setup
# Call tuner to indicate that we are going to do some HP tuning -- passing in the Optuna config
attrs_obj = SpockBuilder(
    LogisticRegressionHP,
    BasicParams,
    desc="Example Logistic Regression Hyper-Parameter Tuning -- Optuna Backend",
).tuner(tuner_config=optuna_config)
```
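Because the config mirrors `optuna.study.create_study`, other study options should pass through as well. Here is a
minimal sketch assuming the `sampler` keyword is forwarded unchanged (the seeding shown uses Optuna's standard
`TPESampler`):

```python
import optuna

# Sketch: seed the underlying sampler for reproducible trials -- assumes
# OptunaTunerConfig forwards the create_study `sampler` keyword as described above
optuna_config = OptunaTunerConfig(
    study_name="Iris Logistic Regression",
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42),
)
```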
## Generate Functionality Still Exists

To get the set of fixed parameters (those that are not hyper-parameters) one simply calls the `generate()` function
just like they would for normal `spock` usage to get the fixed parameter `Spockspace`.

Continuing in `tune.py`:
```python
# Here we need some of the fixed parameters first so we can just call the generate fnc to grab all the fixed params
# prior to starting the sampling process
fixed_params = attrs_obj.generate()
```
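Since `generate()` returns a normal `Spockspace`, the fixed values are reachable with the usual dot notation, for
instance (assuming values like those in the config sketch above):

```python
# Fixed parameters resolve via <Spockspace>.<class>.<attribute>
print(fixed_params.BasicParams.n_trials)  # e.g. 10
print(fixed_params.BasicParams.max_iter)  # e.g. 150
```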
## Sample as an Alternative to Generate

The `sample()` call is the crux of `spock` hyper-parameter tuning support. It draws a hyper-parameter sample from the
underlying backend sampler, combines it with the fixed parameters, and returns a single `Spockspace` with all
usable parameters (accessed with dot notation). For Optuna -- under the hood `spock` uses the define-and-run Optuna
interface -- thus it handles the underlying 'ask' call. The `spock` builder object has a `@property` called
`tuner_status` that returns any necessary backend objects in a dictionary that the user needs to interface with. In the
case of Optuna, this contains both the Optuna `study` and `trial` (as dictionary keys). We use the return of
`tuner_status` to handle the 'tell' call based on the metric of interest (here simple validation accuracy).

Continuing in `tune.py`:
```python
# Iterate through a bunch of optuna trials
for _ in range(fixed_params.BasicParams.n_trials):
    # Call sample on the spock object
    hp_attrs = attrs_obj.sample()
    # Use the currently sampled parameters in a simple LogisticRegression from sklearn
    # (LogisticRegression and the train/validation splits come from earlier in tune.py)
    clf = LogisticRegression(
        C=hp_attrs.LogisticRegressionHP.c,
        solver=hp_attrs.LogisticRegressionHP.solver,
        max_iter=hp_attrs.BasicParams.max_iter,
    )
    clf.fit(X_train, y_train)
    val_acc = clf.score(X_valid, y_valid)
    # Get the status of the tuner -- this dict will contain all the objects needed to update
    tuner_status = attrs_obj.tuner_status
    # Pull the study and trial objects out of the return dictionary and report the metric via the tell call on the
    # study object
    tuner_status["study"].tell(tuner_status["trial"], val_acc)
```
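Once the loop finishes, the same `study` object can report the best trial. This follow-up is a sketch using standard
Optuna `Study` accessors (`best_trial`, `.value`, `.params`), not part of the original example:

```python
# Sketch: inspect the best result after tuning using standard Optuna accessors
best = attrs_obj.tuner_status["study"].best_trial
print(f"Best validation accuracy: {best.value}")
print(f"Best sampled parameters: {best.params}")
```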