
GMM estimator

rex.gmm_estimator.GMMEstimator

__init__(data: jax.typing.ArrayLike, name: str = 'GMM', threshold: float = 1e-07, verbose: bool = True)

Gaussian Mixture Model Estimator.

Parameters:

  • data (ArrayLike) –

    1D array of delay data.

  • name (str, default: 'GMM' ) –

    Name of the model.

  • threshold (float, default: 1e-07 ) –

    Threshold for determining if the data is deterministic.

  • verbose (bool, default: True ) –

    Whether to print progress.
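The criterion behind threshold is not documented here. One plausible reading, sketched below purely as an assumption, is that the data counts as deterministic when its spread (population standard deviation) falls below threshold, in which case a constant delay describes it better than a mixture. is_deterministic is a hypothetical helper, not part of rex.

```python
import statistics
from typing import Sequence


def is_deterministic(data: Sequence[float], threshold: float = 1e-7) -> bool:
    """Hypothetical check (assumption): the data is deterministic when its
    population standard deviation is below `threshold`."""
    if len(data) < 2:
        # A single sample has no spread, so treat it as deterministic.
        return True
    return statistics.pstdev(data) < threshold


print(is_deterministic([0.01, 0.01, 0.01]))   # True
print(is_deterministic([0.01, 0.02, 0.015]))  # False
```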

fit(num_steps: int = 100, num_components: int = 2, step_size: float = 0.05, seed: int = 0)

Fit the model to the data.

Parameters:

  • num_steps (int, default: 100 ) –

    Number of steps to train the model.

  • num_components (int, default: 2 ) –

    Number of components in the mixture model.

  • step_size (float, default: 0.05 ) –

    Step size for the optimizer.

  • seed (int, default: 0 ) –

    Random seed.
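The fitting procedure itself is not described on this page. To illustrate what num_steps, num_components, step_size and seed control, here is a minimal, dependency-free sketch that fits a 1D mixture by plain gradient descent on the mean negative log-likelihood. Everything in it is an assumption: the parametrization (softmax logits for the weights, log standard deviations), the initialization (means drawn from the data using seed, standard deviations set to the data's spread), and plain gradient descent in place of whatever optimizer the estimator actually uses. fit_gmm is a hypothetical name, not part of rex.

```python
import math
import random
import statistics
from typing import List, Sequence, Tuple


def _log_normal(x: float, mu: float, log_sigma: float) -> float:
    # Log density of N(mu, sigma^2) at x, with sigma = exp(log_sigma).
    z = (x - mu) / math.exp(log_sigma)
    return -0.5 * z * z - log_sigma - 0.5 * math.log(2.0 * math.pi)


def _logsumexp(values: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def fit_gmm(
    data: Sequence[float],
    num_steps: int = 100,
    num_components: int = 2,
    step_size: float = 0.05,
    seed: int = 0,
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Hypothetical sketch: fit a 1D GMM by gradient descent on the mean NLL.

    Returns (weights, means, stds, losses), with one loss per step."""
    rng = random.Random(seed)
    n = len(data)
    logits = [0.0] * num_components  # uniform initial weights
    means = [rng.choice(data) for _ in range(num_components)]  # assumed init
    log_sigmas = [math.log(max(statistics.pstdev(data), 1e-3))] * num_components
    losses: List[float] = []
    for _ in range(num_steps):
        norm = _logsumexp(logits)
        log_w = [a - norm for a in logits]
        weights = [math.exp(v) for v in log_w]
        g_logit = [0.0] * num_components
        g_mu = [0.0] * num_components
        g_ls = [0.0] * num_components
        nll = 0.0
        for x in data:
            joint = [log_w[k] + _log_normal(x, means[k], log_sigmas[k])
                     for k in range(num_components)]
            log_px = _logsumexp(joint)
            nll -= log_px / n
            for k in range(num_components):
                r = math.exp(joint[k] - log_px)  # responsibility of component k
                sigma = math.exp(log_sigmas[k])
                z = (x - means[k]) / sigma
                # Gradients of the mean log-likelihood w.r.t. each parameter.
                g_logit[k] += (r - weights[k]) / n
                g_mu[k] += r * z / sigma / n
                g_ls[k] += r * (z * z - 1.0) / n
        losses.append(nll)
        # Descending the NLL means stepping along the log-likelihood gradient.
        for k in range(num_components):
            logits[k] += step_size * g_logit[k]
            means[k] += step_size * g_mu[k]
            log_sigmas[k] += step_size * g_ls[k]
    norm = _logsumexp(logits)
    weights = [math.exp(a - norm) for a in logits]
    stds = [math.exp(s) for s in log_sigmas]
    return weights, means, stds, losses


data = [0.0, 0.1, -0.1, 0.2, -0.2, 5.0, 5.1, 4.9, 5.2, 4.8]
weights, means, stds, losses = fit_gmm(data, num_steps=100, num_components=2)
print(losses[0], losses[-1])  # the loss decreases over training
```

The recorded losses list is the kind of per-step history a loss plot would draw; with a small step_size, each update moves the parameters only slightly, which keeps the descent stable at the cost of slower convergence.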

get_dist(percentile: float = 0.99) -> base.StaticDist

Get the fitted distribution.

Parameters:

  • percentile (float, default: 0.99 ) –

    Percentile used to prune mixture components that contribute little to the distribution.

Returns:

  • StaticDist

    The distribution object.
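The exact pruning rule behind percentile is not documented. A natural interpretation, assumed in this sketch, is to keep the heaviest components until their cumulative weight reaches percentile, drop the rest, and renormalize the survivors. prune_components and renormalize are hypothetical helpers, not rex functions.

```python
from typing import List, Sequence


def prune_components(weights: Sequence[float], percentile: float = 0.99) -> List[int]:
    """Hypothetical rule (assumption): keep the heaviest components until their
    cumulative normalized weight reaches `percentile`; return sorted indices."""
    total = sum(weights)
    order = sorted(range(len(weights)), key=lambda k: weights[k], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for k in order:
        kept.append(k)
        cumulative += weights[k] / total
        if cumulative >= percentile:
            break
    return sorted(kept)


def renormalize(weights: Sequence[float], kept: Sequence[int]) -> List[float]:
    # Rescale the surviving weights so that they sum to one again.
    total = sum(weights[k] for k in kept)
    return [weights[k] / total for k in kept]


kept = prune_components([0.6, 0.395, 0.005], percentile=0.99)
print(kept)  # [0, 1]
print(renormalize([0.6, 0.395, 0.005], kept))
```

Under this rule, a higher percentile keeps more components, and percentile close to 1 keeps nearly all of them.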

plot_hist(ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 100, xmin: float = None, xmax: float = None, num_points: int = 1000, plot_dist: bool = True) -> plt.Axes

Plot the histogram of the data and the fitted distribution.

Parameters:

  • ax (Axes, default: None ) –

    Axes to plot on.

  • edgecolor (str, default: None ) –

    Edge color of the histogram.

  • facecolor (str, default: None ) –

    Face color of the histogram.

  • bins (int, default: 100 ) –

    Number of bins for the histogram.

  • xmin (float, default: None ) –

    Minimum x value for the histogram; useful for keeping outliers out of the plot.

  • xmax (float, default: None ) –

    Maximum x value for the histogram; useful for keeping outliers out of the plot.

  • num_points (int, default: 1000 ) –

    Number of points at which the fitted distribution is evaluated for plotting.

  • plot_dist (bool, default: True ) –

    Whether to plot the fitted distribution.

Returns:

  • Axes

    The axes with the plot.
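To make the roles of xmin, xmax and num_points concrete, the sketch below evaluates a 1D Gaussian mixture density on an evenly spaced grid of num_points points and keeps only the samples inside [xmin, xmax]. The fallback to the data range when xmin or xmax is None is an assumption, as are the helper names mixture_pdf and pdf_curve; the actual plotting (matplotlib) is omitted so the sketch stays self-contained.

```python
import math
from typing import List, Optional, Sequence, Tuple


def mixture_pdf(x: float, weights: Sequence[float], means: Sequence[float],
                stds: Sequence[float]) -> float:
    # Weighted sum of Gaussian densities evaluated at x.
    return sum(
        w * math.exp(-0.5 * ((x - m) / s) ** 2) / (s * math.sqrt(2.0 * math.pi))
        for w, m, s in zip(weights, means, stds)
    )


def pdf_curve(data: Sequence[float], weights: Sequence[float], means: Sequence[float],
              stds: Sequence[float], xmin: Optional[float] = None,
              xmax: Optional[float] = None, num_points: int = 1000
              ) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: return (visible data, grid xs, pdf values)."""
    # Assumed default: fall back to the data range when xmin/xmax are None.
    lo = min(data) if xmin is None else xmin
    hi = max(data) if xmax is None else xmax
    # Samples outside [lo, hi] (e.g. outliers) are left out of the histogram.
    visible = [x for x in data if lo <= x <= hi]
    step = (hi - lo) / (num_points - 1)
    xs = [lo + i * step for i in range(num_points)]
    return visible, xs, [mixture_pdf(x, weights, means, stds) for x in xs]


visible, xs, ys = pdf_curve([-8.0, 0.0, 8.0, 100.0], [1.0], [0.0], [1.0],
                            xmin=-8.0, xmax=8.0, num_points=1601)
print(visible)  # [-8.0, 0.0, 8.0]
```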

plot_loss(ax: plt.Axes = None, edgecolor: str = None) -> plt.Axes

Plot the training loss.

Parameters:

  • ax (Axes, default: None ) –

    Axes to plot on.

  • edgecolor (str, default: None ) –

    Edge color of the plot.

Returns:

  • Axes

    The axes with the plot.

plot_normalized_weights(ax: plt.Axes = None, edgecolor: str = None) -> plt.Axes

Plot the normalized weights.

Parameters:

  • ax (Axes, default: None ) –

    Axes to plot on.

  • edgecolor (str, default: None ) –

    Edge color of the plot.

Returns:

  • Axes

    The axes with the plot.
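How the weights are normalized is not stated on this page. A common choice, assumed here, is to keep unconstrained logits during training and map them to mixture weights with a softmax, so the plotted values are nonnegative and sum to one. normalized_weights is a hypothetical helper.

```python
import math
from typing import List, Sequence


def normalized_weights(logits: Sequence[float]) -> List[float]:
    """Hypothetical (assumption): map unconstrained logits to mixture weights."""
    # Subtract the max before exponentiating to avoid overflow.
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(normalized_weights([0.0, 0.0]))  # [0.5, 0.5]
```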

animate_training(num_frames: int = 30, fig: plt.Figure = None, ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 40, xmin: float = None, xmax: float = None, num_points: int = 1000) -> matplotlib.animation.FuncAnimation

Animate the training process.

Parameters:

  • num_frames (int, default: 30 ) –

    Number of frames to animate.

  • fig (Figure, default: None ) –

    Figure to plot on.

  • ax (Axes, default: None ) –

    Axes to plot on.

  • edgecolor (str, default: None ) –

    Edge color of the histogram.

  • facecolor (str, default: None ) –

    Face color of the histogram.

  • bins (int, default: 40 ) –

    Number of bins for the histogram.

  • xmin (float, default: None ) –

    Minimum x value for the histogram; useful for keeping outliers out of the plot.

  • xmax (float, default: None ) –

    Maximum x value for the histogram; useful for keeping outliers out of the plot.

  • num_points (int, default: 1000 ) –

    Number of points at which the fitted distribution is evaluated for plotting.

Returns:

  • FuncAnimation

    The animation object.
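How num_frames maps onto training steps is not documented. One simple possibility, assumed in this sketch, is to show the model at up to num_frames evenly spaced steps between the first and last training step, dropping duplicates when there are more frames than steps. The hypothetical frame_steps helper below computes those step indices.

```python
from typing import List


def frame_steps(num_steps: int, num_frames: int = 30) -> List[int]:
    """Hypothetical (assumption): up to `num_frames` evenly spaced step indices."""
    if num_frames <= 1 or num_steps <= 1:
        return [0]
    # Round evenly spaced positions in [0, num_steps - 1]; the set drops repeats.
    steps = {round(i * (num_steps - 1) / (num_frames - 1)) for i in range(num_frames)}
    return sorted(steps)


print(len(frame_steps(100, 30)))  # 30
print(frame_steps(10, 30))        # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```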