Model Selection¶
select_n_components ¶
select_n_components(x: Matrix, *, mask: Matrix | None = None, components: Iterable[int] | None = None, config: SelectionConfig | None = None, **opts: object) -> tuple[int, dict[str, object], list[dict[str, object]], VBPCA | None]
Select n_components by sweeping candidate counts and tracking each fit's endpoint metrics.
Args:
x: Data matrix (dense or sparse).
mask: Optional boolean mask with the same shape as x.
components: Candidate component counts. Defaults to
1..min(n_features, n_samples).
config: Selection parameters controlling metric, stopping behavior,
patience, trials, and whether to compute explained variance or
retain the best model.
**opts: Additional options forwarded to the VBPCA constructor and fit.
Returns:
Tuple (best_k, best_metrics, trace, best_model) where:
- `best_k`: chosen component count.
- `best_metrics`: scalar metrics for the best candidate.
- `trace`: list of per-k endpoint metrics.
- `best_model`: the best VBPCA instance, or None if not requested.
Raises:
ValueError: If metric is invalid or no valid components are provided.
cross_validate_components ¶
cross_validate_components(x: Matrix, *, mask: Matrix | None = None, components: Iterable[int] | None = None, config: CVConfig | None = None, **opts: object) -> tuple[int, list[dict[str, object]]]
K-fold cross-validated model selection for VBPCA.
Partitions observed entries into n_splits folds. For each fold
the held-out entries become an xprobe set. All candidate k values
are evaluated on every fold via select_n_components. The final
k is chosen by the 1-SE rule: the smallest k whose mean
metric across folds is within one standard error of the global
minimum.
All tracked metrics (rms, prms, cost) are recorded per fold regardless of which metric is used for selection, so callers can compare selection criteria without re-running.
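The 1-SE selection step can be written down compactly on its own. The sketch below is illustrative only (`one_se_select` is a hypothetical name, and the per-fold metrics are made up); it assumes at least two folds per candidate so the standard error is defined.

```python
# Illustrative 1-SE rule over per-fold metrics (hypothetical helper,
# not the library's implementation).
import math


def one_se_select(fold_metrics: dict[int, list[float]]) -> int:
    """Pick the smallest k whose mean metric is within one standard
    error of the globally best (lowest) mean. Requires >= 2 folds."""
    stats: dict[int, tuple[float, float]] = {}
    for k, vals in fold_metrics.items():
        n = len(vals)
        mean = sum(vals) / n
        var = sum((v - mean) ** 2 for v in vals) / (n - 1)  # sample variance
        stats[k] = (mean, math.sqrt(var) / math.sqrt(n))  # (mean, SE of mean)
    best_k = min(stats, key=lambda k: stats[k][0])
    threshold = stats[best_k][0] + stats[best_k][1]
    return min(k for k, (mean, _) in stats.items() if mean <= threshold)
```

The effect is a bias toward parsimony: a smaller k wins whenever its mean metric is statistically indistinguishable (within one SE) from the minimum.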
Args:
x: Data matrix (dense or sparse), shape
(n_features, n_samples).
mask: Optional boolean mask with the same shape as x.
components: Candidate component counts. Defaults to
1..min(n_features, n_samples).
config: Cross-validation parameters. Uses CVConfig()
defaults when None.
**opts: Additional options forwarded to select_n_components
and ultimately to the VBPCA constructor / fit.
Returns:
Tuple (best_k, cv_results) where:
- `best_k`: selected component count.
- `cv_results`: list of dicts (one per candidate `k`) with keys
  `k`, `mean_<m>`, `std_<m>`, `se_<m>` for each metric,
  and `<m>_fold_<i>` per-fold values.
Raises:
ValueError: If metric is invalid, no valid components
are provided, or n_splits < 2.
Example:
    >>> best_k, cv = cross_validate_components(
    ...     X, components=range(1, 6), config=CVConfig(n_splits=5)
    ... )
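The fold partitioning described above (held-out observed entries become the probe set for each fold) can be sketched as follows. This is an assumed behavior written as a standalone illustration; `partition_observed` is a hypothetical helper, not the library's API.

```python
# Illustrative K-fold split of observed matrix entries (hypothetical
# helper, not the library's implementation). Each fold hides a disjoint
# slice of the observed entries as its probe set.
import numpy as np


def partition_observed(mask: np.ndarray, n_splits: int, seed: int = 0):
    """Yield (train_mask, probe_mask) pairs covering all observed entries."""
    rng = np.random.default_rng(seed)
    idx = np.flatnonzero(mask.ravel())  # flat indices of observed entries
    rng.shuffle(idx)
    for fold in np.array_split(idx, n_splits):
        probe = np.zeros(mask.size, dtype=bool)
        probe[fold] = True
        probe = probe.reshape(mask.shape)
        yield mask & ~probe, probe  # train on the rest, score on the probe
```

Because the probe masks are disjoint and their union is the observed mask, every observed entry is held out exactly once across the K folds.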
SelectionConfig dataclass ¶
SelectionConfig(metric: _Metric = 'prms', stop_on_metric_reversal: bool = False, patience: int | None = None, max_trials: int | None = None, compute_explained_variance: bool = True, return_best_model: bool = False)
Configuration for component selection.
CVConfig dataclass ¶
Configuration for K-fold cross-validated component selection.
Attributes:
metric: Selection metric ("prms" or "cost").
n_splits: Number of cross-validation folds.
one_se_rule: If True, select the smallest k whose mean metric
is within one standard error of the global minimum.
seed: Random seed for fold partitioning and model fitting.