Tuning

Overview

An algorithm’s performance depends heavily on its hyperparameters. Tuning them by hand is not straightforward even for algorithms like FictitiousPlay or OnlineMirrorDescent, which have only one tunable parameter, let alone for PriorDescent or MFOMO, which have several.

To remedy this, MFGLib algorithms are equipped with automatic hyperparameter tuning based on Optuna [1], an open-source framework designed for efficient hyperparameter search.

To describe how hyperparameter tuning can improve algorithm performance, we must first define what we mean by “performance”. The MFGLib tuning process identifies a set of hyperparameters that minimizes an objective. An objective is a function that maps policies, exploitability scores, and runtimes to a real number. In MFGLib, this objective is represented by the Metric protocol.

Note

Metric is a protocol, not a class. This means any object that implements evaluate() is a valid Metric. See User-Defined for more details.
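
The structural-typing idea behind this note can be sketched with typing.Protocol from the standard library. The MetricLike protocol and the two small classes below are hypothetical stand-ins, not MFGLib's actual definitions; the parameter names simply mirror the evaluate() signature documented under User-Defined.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MetricLike(Protocol):
    """Hypothetical stand-in for a metric protocol (not MFGLib's own class)."""

    def evaluate(
        self,
        pis: list[list[Any]],
        expls: list[list[float]],
        rts: list[list[float]],
        solve_kwargs: dict[str, Any],
    ) -> float: ...


class ConstantMetric:
    """Does not inherit from MetricLike, yet satisfies it structurally."""

    def evaluate(self, pis, expls, rts, solve_kwargs) -> float:
        return 0.0


class NotAMetric:
    """Lacks evaluate(), so it does not satisfy the protocol."""


# isinstance() on a runtime-checkable protocol only checks that the
# method exists; it does not validate the method's signature.
is_metric = isinstance(ConstantMetric(), MetricLike)
is_not_metric = isinstance(NotAMetric(), MetricLike)
```

No inheritance is involved: ConstantMetric qualifies only because it has an evaluate() method.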

Given the metric/objective, an algorithm’s tune() method tests various combinations of hyperparameters and identifies the optimal choice.

Algorithm.tune(metric: Metric, envs: Sequence[Environment], pi_0s: Sequence[torch.Tensor] | Literal['uniform'] = 'uniform', solve_kwargs: SolveKwargs | None = None, sampler: optuna.samplers.BaseSampler | None = None, frozen_attrs: Iterable[str] | None = None, n_trials: int | None = None, timeout: float | None = None) -> optuna.Study

Tune the algorithm over multiple environment/initialization pairs.

Parameters:
  • metric – Objective function to minimize.

  • envs – List of environment targets.

  • pi_0s – Policy initializations. envs and pi_0s are “zipped” together when computing the metrics.

  • solve_kwargs – Additional keyword arguments passed to the solver.

  • sampler – The sampler used to explore the search space of the optimization. If None, the default sampler optuna.samplers.TPESampler is used (with seed=0 for reproducibility). The sampler guides how different hyperparameter trials are selected.

  • frozen_attrs – A list of attributes that should be frozen (i.e., fixed) during the optimization process. These attributes will not be considered for optimization, and their values will be taken directly from the instance of the class.

  • n_trials – The number of trials to run. Refer to optuna documentation for further details on the handling of None.

  • timeout – Stop the study after the given number of seconds. Refer to optuna documentation for further details.

Returns:

The result of the hyperparameter tuning process.

Return type:

optuna.Study
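
The “zipping” of envs and pi_0s can be pictured with plain Python. This is only an illustration of the pairing: the pair_targets helper is hypothetical, strings stand in for environments and policy tensors, and both the length check and the treatment of 'uniform' as one initialization per environment are assumptions rather than a description of MFGLib's internals.

```python
from typing import Sequence, Union


def pair_targets(
    envs: Sequence[str],
    pi_0s: Union[Sequence[str], str] = "uniform",
) -> list[tuple[str, str]]:
    """Hypothetical helper: pair the i-th environment with the i-th initialization."""
    if isinstance(pi_0s, str):
        # Assumption: a single keyword applies to every environment.
        pi_0s = [pi_0s] * len(envs)
    if len(envs) != len(pi_0s):
        # Assumption: mismatched lengths are treated as an error.
        raise ValueError("envs and pi_0s must have the same length")
    return list(zip(envs, pi_0s))


# Explicit initializations: one per environment, matched by position.
pairs = pair_targets(["env_a", "env_b"], ["pi_a", "pi_b"])

# Default keyword: every environment gets the same kind of initialization.
default_pairs = pair_targets(["env_a", "env_b"])
```

Under this reading, the metric is computed over one solve per (environment, initialization) pair.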

Note

The keyword arguments provided in solve_kwargs will be forwarded to the call to .solve() during tuning. You can use solve_kwargs to set parameters like atol or rtol.

Notice that tune() returns an optuna.Study object. This object contains the results of the hyperparameter optimization process. Of course, you can extract the results manually, but for convenience we provide a from_study() instance method to initialize a new algorithm instance.

abstract Algorithm.from_study(study: Study) -> Self

Initialize an algorithm instance with tuned hyperparameters.

Examples

>>> from mfglib.alg import PriorDescent
>>> from mfglib.env import Environment
>>> from mfglib.tuning import GeometricMean
>>>
>>> prior_descent = PriorDescent(eta=0.1, n_inner=50)
>>> study = prior_descent.tune(
...     metric=GeometricMean(),
...     envs=[Environment.random_linear(T=4, n=3, m=4.0)],
... )
>>> prior_descent_tuned = prior_descent.from_study(study)

Metrics

MFGLib ships with two built-in metrics and allows for user-defined metrics too.

Built-In

class FailureRate(fail_thresh: float | int | None = None, stat: Literal['iter', 'rt', 'expl'] = 'expl')

The failure rate metric tracks how many “failed” instances were encountered. A call to solve() is considered a failure depending on the value of stat.

  • If stat="iter" then a call to solve() is considered a failure if the number of iterations reaches fail_thresh.

  • If stat="rt" then a call to solve() is considered a failure if the runtime of the call (in seconds) reaches fail_thresh.

  • If stat="expl" then a call to solve() is considered a failure if the exploitability score reaches fail_thresh.

Parameters:
  • fail_thresh – The failure threshold used to determine a “failed” instance. Can only be None when stat="iter" or stat="expl", in which case the solver’s max_iter or atol, respectively, is used as the default threshold.

  • stat – The statistic to monitor during optimization.
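
One way to read the three cases above: each run is reduced to a single statistic (iteration count, runtime, or exploitability), and the run counts as failed when that statistic reaches the threshold. The sketch below is an interpretation, not MFGLib's implementation. The failure_rate function is hypothetical, “reaches” is read as “greater than or equal to”, and returning a fraction rather than a count is an assumption.

```python
from typing import Sequence


def failure_rate(stats: Sequence[float], fail_thresh: float) -> float:
    """Fraction of runs whose statistic reaches fail_thresh (hypothetical helper)."""
    if not stats:
        raise ValueError("at least one run is required")
    # "Reaches" is read here as greater than or equal to the threshold.
    n_failed = sum(1 for s in stats if s >= fail_thresh)
    return n_failed / len(stats)


# Four runs summarized by their iteration counts, with a threshold of 100.
rate = failure_rate([12, 100, 57, 100], fail_thresh=100)
```

Two of the four runs reach the threshold, so rate is 0.5 under these assumptions.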

class GeometricMean(shift: float = 0, stat: Literal['iter', 'rt', 'expl'] = 'expl')

Parameters:
  • shift – An additional shift value for the geometric mean. Defaults to zero.

  • stat – The statistic to be used in the mean.
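
For reference, the shifted geometric mean of values x_1, ..., x_n with shift s is commonly defined as exp(mean(log(x_i + s))) - s. The sketch below follows that common definition. Whether MFGLib applies shift in exactly this way, and how each run is reduced to a single value, is not stated here, so treat shifted_geometric_mean as a hypothetical helper.

```python
import math
from typing import Sequence


def shifted_geometric_mean(values: Sequence[float], shift: float = 0.0) -> float:
    """Common shifted geometric mean: exp(mean(log(x + shift))) - shift."""
    if not values:
        raise ValueError("at least one value is required")
    if any(v + shift <= 0 for v in values):
        raise ValueError("every shifted value must be strictly positive")
    # Averaging in log space avoids overflow from multiplying many values.
    log_mean = sum(math.log(v + shift) for v in values) / len(values)
    return math.exp(log_mean) - shift


# Without a shift this is the ordinary geometric mean (cube root of 1 * 4 * 16).
gm = shifted_geometric_mean([1.0, 4.0, 16.0])

# With shift=1 the values become 2 and 4 before averaging.
shifted = shifted_geometric_mean([1.0, 3.0], shift=1.0)
```

A positive shift keeps values close to zero from dominating the mean, which is the usual reason for including one.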

User-Defined

To implement a user-defined metric, you simply need to create an object which implements an evaluate() method with the following signature.

class Metric.evaluate(self, pis: list[list[Tensor]], expls: list[list[float]], rts: list[list[float]], solve_kwargs: SolveKwargs)

Compute a scalar measure of the given solution trace.

Parameters:
  • pis (list[list[torch.Tensor]]) – A list-of-list of policies. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.

  • expls (list[list[float]]) – A list-of-list of exploitability scores. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.

  • rts (list[list[float]]) – A list-of-list of runtimes. Each item in the outer list represents an (environment, policy) pair, and each item in the inner list corresponds with an algorithm iteration.

  • solve_kwargs (mfglib.alg.abc.SolveKwargs) – Additional arguments passed to solve().

Let’s implement a simple metric – the highest exploitability encountered across all environment/policy pairs. We’ll call it the Linfty metric, since it roughly corresponds to the \(\ell_\infty\) norm.

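import torch

from mfglib.alg.abc import SolveKwargs


class Linfty:
    """Highest exploitability seen across all environment/policy pairs."""

    def evaluate(
        self,
        pis: list[list[torch.Tensor]],
        expls: list[list[float]],
        rts: list[list[float]],
        solve_kwargs: SolveKwargs,
    ) -> float:
        # Running maximum over all (environment, policy) pairs.
        max_expl = 0.0
        for env_expl_list in expls:
            # Largest exploitability over this pair's iterations.
            max_expl_for_env = max(env_expl_list)
            max_expl = max(max_expl, max_expl_for_env)
        return max_expl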
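
Before handing a custom metric to tune(), it can help to call evaluate() directly on small hand-built traces. The sketch below uses a second, hypothetical metric, MeanFinalExpl, which averages the last exploitability of each run. To keep it runnable without MFGLib or torch, None placeholders stand in for policy tensors and an empty dict stands in for solve_kwargs; both are assumptions made for illustration only.

```python
from typing import Any


class MeanFinalExpl:
    """Hypothetical metric: average of the last exploitability of each run."""

    def evaluate(
        self,
        pis: list[list[Any]],
        expls: list[list[float]],
        rts: list[list[float]],
        solve_kwargs: dict[str, Any],
    ) -> float:
        # The last entry of each inner list is the most recent iteration.
        finals = [trace[-1] for trace in expls]
        return sum(finals) / len(finals)


# Two (environment, policy) pairs with three and two recorded iterations.
expls = [[0.9, 0.5, 0.25], [0.8, 0.75]]
rts = [[0.0, 0.1, 0.2], [0.0, 0.1]]
pis = [[None] * 3, [None] * 2]  # placeholders for policy tensors

score = MeanFinalExpl().evaluate(pis, expls, rts, solve_kwargs={})
```

Here score is the mean of 0.25 and 0.75. Checking a metric this way confirms that it accepts the list-of-lists layout and returns a single float.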