Ax Support
spock integrates with the Ax optimization framework through Ax's Service API. See the Ax docs for more information on AxClient.
All examples can be found here.
Defining the Backend
Let's continue with our Ax-specific version of tune.py:
Note that you can still use the @spock decorator to define any parameters that are not hyper-parameters. To illustrate, let's add some fixed parameters (those that are not part of hyper-parameter tuning) that we will use elsewhere in our code.
```python
from spock import spock
from spock.addons.tune import (
    ChoiceHyperParameter,
    RangeHyperParameter,
    spockTuner,
)


@spock
class BasicParams:
    n_trials: int
    max_iter: int


@spockTuner
class LogisticRegressionHP:
    c: RangeHyperParameter
    solver: ChoiceHyperParameter
```

Now we need to tell spock that we intend to do hyper-parameter tuning and which backend we would like to use. We do this by calling the tuner method on the SpockBuilder object, passing in a configuration object for the chosen backend (just like in the basic functionality, this is a chained command, so the builder object is still returned). For Ax, use AxTunerConfig. This config mirrors all options that would be passed into the AxClient constructor and the AxClient.create_experiment function call so that spock can set up the Service API. (Note: @spockTuner decorated classes are passed to the SpockBuilder in exactly the same way as basic @spock decorated classes.)
```python
from spock import SpockBuilder
from spock.addons.tune import AxTunerConfig

# Ax config -- this will internally spawn the AxClient (Service API style), which is returned
# by accessing the tuner_status property on the SpockBuilder object -- note that we need to define
# the objective name that the client will expect to find in the data dictionary when completing trials
ax_config = AxTunerConfig(objective_name="accuracy", minimize=False)

# Use the builder to set up
# Call tuner to indicate that we are going to do some HP tuning -- passing in the Ax tuner config
attrs_obj = SpockBuilder(
    LogisticRegressionHP,
    BasicParams,
    desc="Example Logistic Regression Hyper-Parameter Tuning -- Ax Backend",
).tuner(tuner_config=ax_config)
```
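The chained tuner call works because the method hands back the builder itself. The following is a minimal stdlib sketch of that pattern only. ToyBuilder and ToyTunerConfig are hypothetical stand-ins, not the spock API, and they assume nothing about how spock stores the config internally.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTunerConfig:
    # Hypothetical stand-in for a backend config such as AxTunerConfig
    objective_name: str
    minimize: bool = True


class ToyBuilder:
    """Hypothetical builder illustrating a chainable tuner() method."""

    def __init__(self, *config_classes, desc: str = ""):
        self.config_classes = config_classes
        self.desc = desc
        self.tuner_config: Optional[ToyTunerConfig] = None

    def tuner(self, tuner_config: ToyTunerConfig) -> "ToyBuilder":
        # Store the backend config, then return self so further calls can be chained
        self.tuner_config = tuner_config
        return self


builder = ToyBuilder("LogisticRegressionHP", "BasicParams", desc="demo").tuner(
    tuner_config=ToyTunerConfig(objective_name="accuracy", minimize=False)
)
```

Returning self keeps construction and backend selection in one expression, which is why the result can be assigned straight to attrs_obj.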
Generate Functionality Still Exists
To get the set of fixed parameters (those that are not hyper-parameters), simply call the generate() function just as you would in normal spock usage. It returns a Spockspace containing the fixed parameters.
Continuing in tune.py:
```python
# Here we need some of the fixed parameters first, so we can just call the generate fnc to grab
# all the fixed params prior to starting the sampling process
fixed_params = attrs_obj.generate()
```

Sample as an Alternative to Generate
The sample() call is the crux of spock's hyper-parameter tuning support. It draws a hyper-parameter sample from the underlying backend sampler, combines it with the fixed parameters, and returns a single Spockspace containing all usable parameters (accessed with dot notation). For Ax, spock uses the Service API (with an AxClient) under the hood, so it handles the underlying call to get the next trial. The spock builder object has a @property called tuner_status that returns, in a dictionary, any backend objects the user needs to interface with. In the case of Ax, this dictionary contains the AxClient (under the key "client") and the current trial index (under the key "trial_index"). We use the return of tuner_status to complete each trial via the complete_trial call, based on the metric of interest: here, just the simple validation accuracy. (Remember that during AxTunerConfig instantiation we set the objective_name to "accuracy".) We also set the SEM to 0.0 since we are not using it in this example.
See here for Ax documentation on completing trials.
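The merge that sample() performs can be pictured with a small stdlib sketch. This is not spock's implementation: merge_params is a hypothetical helper, SimpleNamespace stands in for a Spockspace, and the parameter values are made up for illustration.

```python
from types import SimpleNamespace
from typing import Any, Dict


def merge_params(
    fixed: Dict[str, Dict[str, Any]], sampled: Dict[str, Dict[str, Any]]
) -> SimpleNamespace:
    """Combine fixed and sampled parameter groups into one dot-accessible namespace.

    Hypothetical illustration only; not spock's implementation.
    """
    groups: Dict[str, Dict[str, Any]] = {}
    for source in (fixed, sampled):
        for class_name, values in source.items():
            # One namespace per config class, keyed by the class name
            groups.setdefault(class_name, {}).update(values)
    return SimpleNamespace(
        **{name: SimpleNamespace(**values) for name, values in groups.items()}
    )


# Made-up values standing in for generate() output and one drawn sample
fixed = {"BasicParams": {"n_trials": 10, "max_iter": 150}}
sampled = {"LogisticRegressionHP": {"c": 0.5, "solver": "lbfgs"}}
params = merge_params(fixed, sampled)
print(params.LogisticRegressionHP.c, params.BasicParams.max_iter)  # 0.5 150
```

Both the fixed and the sampled groups end up reachable with the same dot notation, which is what lets the trial loop read hp_attrs.BasicParams.max_iter alongside hp_attrs.LogisticRegressionHP.c.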
Continuing in tune.py:
```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Assumed data setup for this snippet: any train/validation split works here
X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)

# Iterate through a bunch of ax trials
for _ in range(fixed_params.BasicParams.n_trials):
    # Call sample on the spock object
    hp_attrs = attrs_obj.sample()
    # Use the currently sampled parameters in a simple LogisticRegression from sklearn
    clf = LogisticRegression(
        C=hp_attrs.LogisticRegressionHP.c,
        solver=hp_attrs.LogisticRegressionHP.solver,
        max_iter=hp_attrs.BasicParams.max_iter,
    )
    clf.fit(X_train, y_train)
    val_acc = clf.score(X_valid, y_valid)
    # Get the status of the tuner -- this dict will contain all the objects needed to update
    tuner_status = attrs_obj.tuner_status
    # Pull the AxClient object and trial index out of the return dictionary and call 'complete_trial'
    # on the AxClient object with the correct raw_data that contains the objective name
    tuner_status["client"].complete_trial(
        trial_index=tuner_status["trial_index"],
        raw_data={"accuracy": (val_acc, 0.0)},
    )
```
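To make the bookkeeping side of the loop concrete, here is a stdlib-only mock that runs without spock or Ax installed. MockClient, toy_objective, and the random search are hypothetical stand-ins. They are not AxClient or Ax's generation strategy. The mock keeps only the parts the text describes: a tuner_status dictionary with "client" and "trial_index" keys, and raw_data mapping the objective name to a (mean, SEM) tuple.

```python
import random
from typing import Dict, Tuple


class MockClient:
    """Hypothetical stand-in for AxClient: only records completed trials."""

    def __init__(self, objective_name: str, minimize: bool) -> None:
        self.objective_name = objective_name
        self.minimize = minimize
        self.results: Dict[int, Tuple[float, float]] = {}

    def complete_trial(
        self, trial_index: int, raw_data: Dict[str, Tuple[float, float]]
    ) -> None:
        # raw_data maps the objective name to a (mean, SEM) tuple
        self.results[trial_index] = raw_data[self.objective_name]

    def best_trial(self) -> int:
        # Pick the trial with the best mean, respecting the minimize flag
        pick = min if self.minimize else max
        return pick(self.results, key=lambda idx: self.results[idx][0])


def toy_objective(c: float) -> float:
    # Hypothetical score that peaks at c = 1.0
    return 1.0 - abs(c - 1.0) / 10.0


client = MockClient(objective_name="accuracy", minimize=False)
rng = random.Random(0)
for trial_index in range(5):
    # Random search stands in for Ax's generation strategy; indices are simply enumerated
    c = rng.uniform(0.01, 10.0)
    tuner_status = {"client": client, "trial_index": trial_index}
    tuner_status["client"].complete_trial(
        trial_index=tuner_status["trial_index"],
        raw_data={"accuracy": (toy_objective(c), 0.0)},
    )

best = client.best_trial()
print(best, client.results[best])
```

Because the objective is registered with minimize=False, the best trial is the one with the highest reported mean. That is why the objective_name passed to AxTunerConfig must match the key used in raw_data.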