RL4CRN.policies.parameter_generator_from_distribution

Neural parameter generator used by RL policies to sample reaction parameters from learned probability distributions.

This module defines ParameterGeneratorFromDistribution, a small wrapper around a feed-forward backbone (FFNN) that maps a conditioning embedding (e.g., an encoded IOCRN state and/or a reaction choice) to the parameters of a distribution, then provides:

  • sampling of parameter vectors (continuous and/or discrete),
  • log-probability evaluation of provided samples (for policy-gradient objectives),
  • entropy computation (for exploration/regularization).

Supported families include several LogNormal variants (1D, independent multivariate, and a correlated multivariate LogNormal via an exponentiated MultivariateNormal), a correlated MultivariateNormal, and multivariate categorical distributions. Optional masks allow a single fixed-size generator to handle reactions with different effective parameter dimensionalities and to invalidate impossible categorical choices.

ParameterGeneratorFromDistribution

Bases: Module

Neural module that parameterizes and samples reaction-parameter vectors from a chosen probability distribution, conditioned on a learned embedding.

This class is used by policies to generate continuous and/or discrete reaction parameters. A small feed-forward backbone (FFNN) maps an input embedding x (typically the encoded IOCRN state concatenated with a one-hot reaction choice) to the parameters of a distribution. The module then:

  1. constructs the corresponding distribution object,
  2. samples parameters (or evaluates the log-probability of provided samples),
  3. returns samples, summed log-probabilities, and summed entropies per batch element.
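The three steps above can be sketched with plain `torch.distributions`, using a single `Linear` layer as a stand-in for the FFNN backbone and an independent LogNormal head. This is a hedged illustration of the call pattern, not the library's implementation; all names here are local to the example.

```python
import torch

torch.manual_seed(0)
N, D, E = 4, 3, 8                       # batch size, parameter dims, embedding size
backbone = torch.nn.Linear(E, 2 * D)    # stand-in backbone: per-dimension mean/std
x = torch.randn(N, E)                   # conditioning embedding

raw = backbone(x)
mean = torch.nn.functional.softplus(raw[:, :D]) + 1e-6  # positive mean
std = torch.nn.functional.softplus(raw[:, D:]) + 1e-6   # positive std

# Moment-match mean/std to the underlying Normal parameters (see Notes section).
sigma = torch.sqrt(torch.log1p((std / mean) ** 2))
mu = torch.log(mean) - 0.5 * sigma ** 2

dist = torch.distributions.LogNormal(mu, sigma)   # 1. construct the distribution
samples = dist.sample()                           # 2. sample parameters
log_probs = dist.log_prob(samples).sum(dim=-1)    # 3. sum per batch element
entropies = dist.entropy().sum(dim=-1)

print(samples.shape, log_probs.shape, entropies.shape)
```

The reduction over the last dimension is what produces the per-batch-element `(N,)` shapes described below.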

Supported distributions (selected via distribution["type"]):

Continuous
  • "lognormal_1D": One-dimensional LogNormal with learned mean/std (in mean/std space, converted to underlying Normal parameters mu, sigma).
  • "lognormal_independent": Factorized (independent) LogNormal across D dimensions. The backbone outputs per-dimension mean/std (in mean/std space), which are converted to mu, sigma of the underlying Normal.
  • "lognormal_processed": Correlated multivariate LogNormal built as an exponentiated MultivariateNormal. The backbone outputs a bounded mean vector mu and a Cholesky factor L for the covariance in log-space, enabling correlations and ensuring PSD covariance. Supports variable active dimensionality per batch element via an optional mask.
  • "multivariate_normal": Correlated multivariate Normal with learned mean and Cholesky factor L. (Mask currently accepted but not yet used to reduce dimensionality.)
Discrete
  • "categorical": Multivariate categorical distribution over a fixed set of categories (shared across dimensions). Implemented via MultiVariateCategorical. Supports:

    • logit_mask to invalidate some category combinations,
    • dimension_mask to zero-out unused discrete dimensions.
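The effect of a logit mask can be illustrated with a single categorical head: logits set to `-inf` receive exactly zero probability after the softmax, so masked categories can never be sampled. This is a generic sketch; the exact mask convention used by `MultiVariateCategorical` (True-means-valid below is an assumption) may differ.

```python
import torch

logits = torch.tensor([1.0, 2.0, 0.5])
logit_mask = torch.tensor([True, False, True])   # assume True = valid category
masked = logits.masked_fill(~logit_mask, float("-inf"))
probs = torch.softmax(masked, dim=-1)
print(probs)   # middle entry is exactly 0
```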
Masking semantics

Masks allow parameter vectors of different effective sizes to coexist in a fixed-size tensor. Depending on the distribution:

  • For categorical: dimension_mask zeros out inactive dimensions; logit_mask can set invalid logits to -inf.
  • For processed lognormal: mask (shape (N, D)) indicates the number of active dimensions per batch element; the implementation groups samples by active dimension count and builds smaller distributions for efficiency and numerical stability.
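The grouping strategy for the processed lognormal can be sketched as follows: batch elements are bucketed by their active dimension count, and one smaller distribution is built per bucket. This is an assumed reconstruction of the idea, not the library's code, and it uses an identity covariance purely for brevity.

```python
import torch

torch.manual_seed(0)
mask = torch.tensor([[1, 1, 0], [1, 1, 1], [1, 0, 0], [1, 1, 0]])  # (N, D)
active = mask.sum(dim=1)        # active dimensions per batch element
mu = torch.zeros(4, 3)          # stand-in log-space means from the backbone

log_probs = torch.zeros(4)
for k in active.unique():
    idx = (active == k).nonzero(as_tuple=True)[0]
    d = int(k)
    # Build a smaller d-dimensional distribution only for this group.
    dist = torch.distributions.MultivariateNormal(
        mu[idx, :d], covariance_matrix=torch.eye(d)
    )
    samples_k = dist.sample()            # shape (len(idx), d)
    log_probs[idx] = dist.log_prob(samples_k)

print(log_probs)
```

Working in the reduced dimensionality avoids inverting padded covariance rows, which is where the stated numerical-stability benefit comes from.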
Inputs and outputs

The module's forward signature depends on the chosen distribution type, but the returned values are consistent:

  • samples: tensor of sampled parameters (shape (N, D), or (N, 1) in the 1D case),
  • log_probs: tensor of total log-probabilities per batch element (shape (N,)),
  • entropies: tensor of total entropies per batch element (shape (N,)).

Notes
  • The constructor dynamically assigns self.forward to a distribution-specific implementation. This keeps call-sites uniform while allowing different argument sets (e.g., mask vs logit_mask/dimension_mask).
  • For LogNormal parameterization in 1D/independent modes, the backbone outputs positive mean/std via softplus and converts them to underlying Normal parameters: sigma = sqrt(log(1 + std^2 / mean^2)) and mu = log(mean) - 0.5*sigma^2, which yields a valid LogNormal with exactly the requested mean and std.
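The moment-matching conversion in the last note can be checked directly against the closed-form LogNormal moments (a pure-math sketch, standard library only):

```python
import math

mean, std = 2.0, 0.5
sigma = math.sqrt(math.log(1 + std**2 / mean**2))
mu = math.log(mean) - 0.5 * sigma**2

# Closed-form LogNormal moments: E[X] = exp(mu + sigma^2/2),
# Var[X] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2).
ln_mean = math.exp(mu + 0.5 * sigma**2)
ln_std = math.sqrt((math.exp(sigma**2) - 1) * math.exp(2 * mu + sigma**2))
print(ln_mean, ln_std)   # recovers 2.0 and 0.5 up to float error
```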

__init__(distribution, backbone_attributes, device='cpu')

Construct a parameter generator for a specified distribution family.

PARAMETER DESCRIPTION
distribution

dict Distribution specification. Must include:

  • "type": one of {"lognormal_1D", "lognormal_independent", "lognormal_processed", "multivariate_normal", "categorical"}.

Additional required keys depend on the type:

  • "dim" (int): required for all multivariate types and for "categorical".

  • "categories" (torch.Tensor): required for "categorical"; 1D tensor of category values shared across dimensions.

Optional keys for "lognormal_processed":

  • squash (float): scaling before tanh for mean bounding.
  • mu_max (float): max absolute value for bounded log-space mean.
  • sigma_min (float): lower bound added to Cholesky diagonal.
  • sigma_max (float or None): optional upper bound on Cholesky diagonal.
  • off_scale (float or None): if set, off-diagonals are bounded by off_scale * tanh(off_raw).

backbone_attributes

dict FFNN configuration with keys:

  • "input_size" (int): embedding dimension expected by the generator,
  • "hidden_size" (int): FFNN hidden width,
  • "num_layers" (int): number of FFNN layers.

The backbone output size is determined by the distribution type (e.g., 2*D for independent lognormal, D + D + D(D-1)/2 for Cholesky-parameterized covariances).
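The output-size rule can be made concrete with a small helper. The function name and the categorical case (`dim * n_categories`, one logit per category per dimension) are illustrative assumptions; the Cholesky count is mean (D) + diagonal (D) + strict lower triangle (D(D-1)/2), as stated above.

```python
def backbone_output_size(dist_type: str, dim: int = 1, n_categories: int = 1) -> int:
    """Hypothetical helper: backbone output width per distribution type."""
    if dist_type == "lognormal_1D":
        return 2                                    # mean, std
    if dist_type == "lognormal_independent":
        return 2 * dim                              # per-dimension mean, std
    if dist_type in ("lognormal_processed", "multivariate_normal"):
        return dim + dim + dim * (dim - 1) // 2     # mean + Cholesky diag + off-diag
    if dist_type == "categorical":
        return dim * n_categories                   # assumed: logits per dim
    raise ValueError(f"unknown distribution type: {dist_type}")

print(backbone_output_size("multivariate_normal", dim=3))   # → 9
```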

device

str or torch.device, default="cpu"

Device where parameters, backbone, and intermediate tensors are allocated.