Tuning
Overview
An algorithm’s performance depends heavily on its choice of hyperparameters. Manual tuning is
difficult even for algorithms like FictitiousPlay or OnlineMirrorDescent, which have only
one tunable parameter, and harder still for PriorDescent or MFOMO,
which have several.
To remedy this, MFGLib algorithms are equipped with automatic hyperparameter tuning based on Optuna [1], an open-source framework designed for efficient hyperparameter search.
To describe how hyperparameter tuning can increase algorithm performance, we must first define what we mean by
“performance”. The MFGLib tuning process identifies a set of hyperparameters that minimizes an objective.
An objective is a function that maps policies, exploitability scores, and runtimes to a real number. In MFGLib,
this objective is represented by the Metric protocol.
Note
Metric is a protocol, not a class. This means any object that implements evaluate()
is a valid Metric. See User-Defined for more details.
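The structural-typing idea behind the note can be sketched in plain Python. The `MetricLike` protocol and the `WorstFinalExpl` class below are hypothetical stand-ins written for illustration; they are not MFGLib names, and `solve_kwargs` is assumed to behave like a dictionary.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MetricLike(Protocol):
    """Hypothetical stand-in for MFGLib's Metric protocol (illustration only)."""

    def evaluate(
        self,
        pis: list[list[Any]],
        expls: list[list[float]],
        rts: list[list[float]],
        solve_kwargs: dict[str, Any],
    ) -> float: ...


class WorstFinalExpl:
    """No base class required: defining evaluate() is enough to satisfy the protocol."""

    def evaluate(self, pis, expls, rts, solve_kwargs) -> float:
        # Largest final exploitability across all runs.
        return max(trace[-1] for trace in expls)


metric = WorstFinalExpl()

# The runtime check only looks for an evaluate() attribute; it does not
# verify the signature.
is_metric = isinstance(metric, MetricLike)

score = metric.evaluate(
    pis=[[], []],
    expls=[[0.9, 0.2], [0.8, 0.4]],
    rts=[[0.1, 0.2], [0.1, 0.3]],
    solve_kwargs={},
)
```

Here `WorstFinalExpl` never inherits from `MetricLike`, yet it passes the check because it has the required method.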
Given the metric/objective, an algorithm’s tune() method tests various combinations of hyperparameters
and identifies the optimal choice.
- Algorithm.tune(metric: Metric, envs: Sequence[Environment], pi_0s: Sequence[torch.Tensor] | Literal['uniform'] = 'uniform', solve_kwargs: SolveKwargs | None = None, sampler: optuna.samplers.BaseSampler | None = None, frozen_attrs: Iterable[str] | None = None, n_trials: int | None = None, timeout: float | None = None) -> optuna.Study
Tune the algorithm over multiple environment/initialization pairs.
- Parameters:
metric – Objective function to minimize.
envs – List of environment targets.
pi_0s – Policy initializations. envs and pi_0s are “zipped” together when computing the metrics.
solve_kwargs – Additional keyword arguments passed to the solver.
sampler – The sampler used to explore the search space of the optimization. If None, the default sampler optuna.samplers.TPESampler is used (with seed=0 for reproducibility). The sampler guides how different hyperparameter trials are selected.
frozen_attrs – A list of attributes that should be frozen (i.e., fixed) during the optimization process. These attributes will not be considered for optimization, and their values will be taken directly from the instance of the class.
n_trials – The number of trials to run. Refer to the optuna documentation for further details on the handling of None.
timeout – Stop the study after the given number of seconds. Refer to the optuna documentation for further details.
- Returns:
The result of the hyperparameter tuning process.
- Return type:
optuna.Study
Note
The keyword arguments provided in solve_kwargs will be forwarded to the call to .solve()
during tuning. You can use solve_kwargs to set parameters like atol or rtol.
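To make the overall flow concrete, here is a minimal sketch of a tuning loop in plain Python. It is not MFGLib's implementation: it substitutes an exhaustive grid search for Optuna's sampler, and `toy_solve`, `mean_final_expl`, and the numeric stand-ins for environments and initial policies are all hypothetical. It only illustrates how positionally paired `envs` and `pi_0s` feed a single scalar objective that is then minimized.

```python
def toy_solve(eta: float, difficulty: float, pi_0: float, max_iter: int = 20) -> list[float]:
    """Hypothetical solver returning one exploitability value per iteration."""
    expl = difficulty + abs(pi_0)
    rate = 0.2 + abs(eta - 0.5)  # contracts fastest when eta is near 0.5
    trace = []
    for _ in range(max_iter):
        expl *= rate
        trace.append(expl)
    return trace


def mean_final_expl(expls: list[list[float]]) -> float:
    """Toy objective: average final exploitability over all runs (lower is better)."""
    return sum(trace[-1] for trace in expls) / len(expls)


envs = [1.0, 2.0]    # stand-ins for Environment instances
pi_0s = [0.5, 0.25]  # stand-ins for initial policies, paired with envs by position


def objective(eta: float) -> float:
    # envs and pi_0s are zipped: run i uses (envs[i], pi_0s[i]).
    expls = [toy_solve(eta, env, pi_0) for env, pi_0 in zip(envs, pi_0s)]
    return mean_final_expl(expls)


# Grid search stands in for the sampler: try each candidate, keep the minimizer.
candidates = [0.1, 0.3, 0.5, 0.7, 0.9]
scores = {eta: objective(eta) for eta in candidates}
best_eta = min(scores, key=scores.get)
```

A sampler such as TPE replaces the fixed `candidates` list with trial values proposed adaptively from earlier results, but the objective evaluation has the same shape.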
Notice that tune() returns an optuna.Study object. This object contains the results of the
hyperparameter optimization process. Of course, you can extract the results manually, but for convenience
we provide a from_study() instance method to initialize a new algorithm instance.
- abstract Algorithm.from_study(study: Study) -> Self
Initialize an algorithm instance with tuned hyperparameters.
Examples
>>> from mfglib.alg import PriorDescent
>>> from mfglib.env import Environment
>>> from mfglib.tuning import GeometricMean
>>>
>>> prior_descent = PriorDescent(eta=0.1, n_inner=50)
>>> study = prior_descent.tune(
...     metric=GeometricMean(),
...     envs=[Environment.random_linear(T=4, n=3, m=4.0)],
... )
>>> prior_descent_tuned = prior_descent.from_study(study)
Metrics
MFGLib ships with two built-in metrics and allows for user-defined metrics too.
Built-In
- class FailureRate(fail_thresh: float | int | None = None, stat: Literal['iter', 'rt', 'expl'] = 'expl')
The failure rate metric tracks how many “failed” instances were encountered. A call to solve() is considered a failure depending on the value of stat.
If stat="iter" then a call to solve() is considered a failure if the number of iterations reaches fail_thresh.
If stat="rt" then a call to solve() is considered a failure if the runtime of the call (in seconds) reaches fail_thresh.
If stat="expl" then a call to solve() is considered a failure if the exploitability score reaches fail_thresh.
- Parameters:
fail_thresh – The failure threshold used to determine a “failed” instance. Can only be None when stat="iter" or stat="expl", in which case the solver’s max_iter or atol, respectively, is used as the default threshold.
stat – The statistic to monitor during optimization.
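As a rough illustration of the thresholding idea, the sketch below computes the fraction of runs whose monitored statistic reaches a threshold. The helper `failure_rate` is hypothetical, and two things are assumptions rather than documented behavior: that "reaches" means greater than or equal to, and that the result is normalized by the number of runs. How MFGLib derives the per-run statistic from the solution traces is not shown here.

```python
def failure_rate(stats: list[float], fail_thresh: float) -> float:
    """Fraction of runs whose monitored statistic reaches fail_thresh."""
    if not stats:
        raise ValueError("need at least one run")
    # Assumption: "reaches" is interpreted as >= fail_thresh.
    n_failed = sum(1 for s in stats if s >= fail_thresh)
    return n_failed / len(stats)


# Hypothetical iteration counts for four solve() calls with max_iter = 100.
iters_per_run = [100, 37, 100, 12]
rate = failure_rate(iters_per_run, fail_thresh=100)
```

Two of the four runs hit the iteration cap, so half of the runs count as failures under these assumptions.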
- class GeometricMean(shift: float = 0, stat: Literal['iter', 'rt', 'expl'] = 'expl')
- Parameters:
shift – An additional shift value for the geometric mean. Defaults to zero.
stat – The statistic to be used in the mean.
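For intuition about the `shift` parameter, the following sketch implements the textbook shifted geometric mean: add the shift to every value, take the geometric mean, then subtract the shift again. Whether MFGLib uses exactly this formula (in particular, whether it subtracts the shift afterwards) is an assumption; `shifted_geometric_mean` is a hypothetical helper, not a library function.

```python
import math


def shifted_geometric_mean(values: list[float], shift: float = 0.0) -> float:
    """Textbook shifted geometric mean: exp(mean(log(v + shift))) - shift."""
    if not values:
        raise ValueError("need at least one value")
    if any(v + shift <= 0 for v in values):
        raise ValueError("every value plus shift must be strictly positive")
    # Averaging logs avoids overflow/underflow from a long running product.
    log_mean = sum(math.log(v + shift) for v in values) / len(values)
    return math.exp(log_mean) - shift


# Without a shift, values spanning orders of magnitude are averaged on a log scale.
gm_plain = shifted_geometric_mean([1e-1, 1e-3])

# A positive shift lets zero-valued statistics participate.
gm_shifted = shifted_geometric_mean([0.0, 3.0], shift=1.0)
```

A geometric mean is less dominated by a single large value than an arithmetic mean, which is useful when statistics such as exploitability vary over several orders of magnitude across environments.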
User-Defined
To implement a user-defined metric, you simply need to create an object which implements an
evaluate() method with the following signature.
- class Metric.evaluate(self, pis: list[list[Tensor]], expls: list[list[float]], rts: list[list[float]], solve_kwargs: SolveKwargs)
Compute scalar measure of the given solution trace.
- Parameters:
pis (list[list[torch.Tensor]]) – A list-of-list of policies. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.
expls (list[list[float]]) – A list-of-list of exploitability scores. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.
rts (list[list[float]]) – A list-of-list of runtimes. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.
solve_kwargs (mfglib.alg.abc.SolveKwargs) – Additional arguments passed to
solve().
Let’s implement a simple metric – the highest exploitability encountered across all environment/policy pairs.
We’ll call it the Linfty metric, since it roughly corresponds to the \(\ell_\infty\) norm.
import torch

from mfglib.alg.abc import SolveKwargs


class Linfty:
    def evaluate(
        self,
        pis: list[list[torch.Tensor]],
        expls: list[list[float]],
        rts: list[list[float]],
        solve_kwargs: SolveKwargs,
    ) -> float:
        # Track the largest exploitability seen in any run at any iteration.
        max_expl = 0.0
        for env_expl_list in expls:
            max_expl_for_env = max(env_expl_list)
            max_expl = max(max_expl, max_expl_for_env)
        return max_expl