parameter_generator_from_distribution
RL4CRN.policies.parameter_generator_from_distribution
Neural parameter generator used by RL policies to sample reaction parameters from learned probability distributions.
This module defines ParameterGeneratorFromDistribution, a small wrapper around a
feed-forward backbone (FFNN) that maps a conditioning embedding (e.g., an encoded IOCRN
state and/or a reaction choice) to the parameters of a distribution, then provides:
- sampling of parameter vectors (continuous and/or discrete),
- log-probability evaluation of provided samples (for policy-gradient objectives),
- entropy computation (for exploration/regularization).
Supported families include several LogNormal variants (1D, independent multivariate, and a correlated multivariate LogNormal via an exponentiated MultivariateNormal), a correlated MultivariateNormal, and multivariate categorical distributions. Optional masks allow a single fixed-size generator to handle reactions with different effective parameter dimensionalities and to invalidate impossible categorical choices.
ParameterGeneratorFromDistribution
Bases: Module
Neural module that parameterizes and samples reaction-parameter vectors from a chosen probability distribution, conditioned on a learned embedding.
This class is used by policies to generate continuous and/or discrete reaction parameters.
A small feed-forward backbone (FFNN) maps an input embedding x (typically the encoded
IOCRN state concatenated with a one-hot reaction choice) to the parameters of a distribution.
The module then:
- constructs the corresponding distribution object,
- samples parameters (or evaluates the log-probability of provided samples),
- returns samples, summed log-probabilities, and summed entropies per batch element.
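The flow above (embedding -> distribution -> samples, summed log-probabilities, summed entropies) can be sketched in PyTorch as follows. This is a minimal illustration, not the module's code. TinyParameterGenerator is a hypothetical name, and it uses a plain diagonal Normal rather than one of the supported families so that only the control flow is on display: sampling when no samples are given, scoring when they are.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class TinyParameterGenerator(nn.Module):
    """Hypothetical stand-in: embedding -> distribution -> (samples, log_probs, entropies)."""

    def __init__(self, input_size: int, dim: int, hidden: int = 32):
        super().__init__()
        # Small feed-forward backbone emitting a location and a raw scale per dimension.
        self.backbone = nn.Sequential(
            nn.Linear(input_size, hidden), nn.ReLU(), nn.Linear(hidden, 2 * dim)
        )

    def forward(self, x: torch.Tensor, samples: Optional[torch.Tensor] = None):
        loc, raw_scale = self.backbone(x).chunk(2, dim=-1)
        dist = Normal(loc, F.softplus(raw_scale) + 1e-4)  # softplus keeps the scale positive
        if samples is None:
            samples = dist.sample()  # not reparameterized; gradients flow via log_prob
        log_probs = dist.log_prob(samples).sum(dim=-1)  # (N,) total over dimensions
        entropies = dist.entropy().sum(dim=-1)  # (N,) total over dimensions
        return samples, log_probs, entropies


torch.manual_seed(0)
gen = TinyParameterGenerator(input_size=8, dim=3)
x = torch.randn(5, 8)
samples, log_probs, entropies = gen(x)   # sampling mode
_, log_probs_again, _ = gen(x, samples)  # scoring mode, e.g. inside a policy-gradient loss
print(samples.shape, log_probs.shape, entropies.shape)
# torch.Size([5, 3]) torch.Size([5]) torch.Size([5])
```

Passing previously drawn samples back in is what makes the same module usable both for acting and for computing policy-gradient objectives.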
Supported distributions (selected via distribution["type"]):
Continuous
- "lognormal_1D": One-dimensional LogNormal with a learned mean/std (in mean/std space, converted to the underlying Normal parameters mu, sigma).
- "lognormal_independent": Factorized (independent) LogNormal across D dimensions. The backbone outputs a per-dimension mean/std (in mean/std space), which are converted to the mu, sigma of the underlying Normal.
- "lognormal_processed": Correlated multivariate LogNormal built as an exponentiated MultivariateNormal. The backbone outputs a bounded mean vector mu and a Cholesky factor L of the log-space covariance, which allows correlations and guarantees a positive semi-definite (PSD) covariance. Supports a variable number of active dimensions per batch element via an optional mask.
- "multivariate_normal": Correlated multivariate Normal with a learned mean and Cholesky factor L. (The mask is currently accepted but not yet used to reduce dimensionality.)
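One way to realize the "lognormal_processed" construction with torch.distributions is to wrap a MultivariateNormal in a TransformedDistribution with ExpTransform. In this sketch the helper correlated_lognormal, the tanh-based bound on mu, and the softplus on the Cholesky diagonal are assumptions for illustration; the module's actual parameterization may differ. Because TransformedDistribution does not implement entropy(), the entropy is computed from the identity H[exp X] = H[X] + sum_i mu_i.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, TransformedDistribution
from torch.distributions.transforms import ExpTransform


def correlated_lognormal(raw_mean, raw_tril, mean_bound=3.0):
    """Build exp(MVN(mu, L L^T)) from unconstrained outputs of shape (N, D) and (N, D, D)."""
    # Bounded log-space mean; the tanh scaling is an illustrative assumption.
    mu = mean_bound * torch.tanh(raw_mean)
    # Cholesky factor: strictly-lower part kept as-is, diagonal forced positive.
    diag = F.softplus(torch.diagonal(raw_tril, dim1=-2, dim2=-1)) + 1e-4
    L = torch.tril(raw_tril, diagonal=-1) + torch.diag_embed(diag)
    base = MultivariateNormal(mu, scale_tril=L)
    return base, TransformedDistribution(base, [ExpTransform()])


torch.manual_seed(0)
base, dist = correlated_lognormal(torch.randn(4, 3), torch.randn(4, 3, 3))
samples = dist.sample()             # (4, 3), strictly positive, correlated in log-space
log_probs = dist.log_prob(samples)  # (4,), includes the -sum(log y) Jacobian term
# TransformedDistribution has no entropy(); for Y = exp(X): H[Y] = H[X] + sum_i mu_i.
entropies = base.entropy() + base.loc.sum(dim=-1)  # (4,)
```

Parameterizing through a Cholesky factor means any backbone output yields a valid covariance, so no projection or eigenvalue clipping is needed.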
Discrete
- "categorical": Multivariate categorical distribution over a fixed set of categories (shared across dimensions), implemented via MultiVariateCategorical. Supports logit_mask to invalidate some category combinations and dimension_mask to zero out unused discrete dimensions.
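The categorical masking can be illustrated with a plain torch.distributions.Categorical over logits of shape (N, D, K). The function below is a hypothetical stand-in for MultiVariateCategorical, not its implementation. It assumes logit_mask is a boolean tensor with True marking allowed categories and that every dimension keeps at least one allowed category.

```python
import torch
from torch.distributions import Categorical


def masked_categorical(logits, logit_mask, dimension_mask, samples=None):
    """logits, logit_mask: (N, D, K); dimension_mask: (N, D).

    Assumes each dimension keeps at least one allowed category
    (otherwise its logits would be all -inf).
    """
    dist = Categorical(logits=logits.masked_fill(~logit_mask, float("-inf")))
    active = dimension_mask.bool()
    if samples is None:
        # Inactive dimensions are filled with index 0.
        samples = torch.where(active, dist.sample(), torch.zeros_like(active, dtype=torch.long))
    zeros = torch.zeros(active.shape, dtype=logits.dtype)
    # torch.where (not multiplication) so a masked-out -inf never turns into NaN.
    log_probs = torch.where(active, dist.log_prob(samples), zeros).sum(dim=-1)
    entropies = torch.where(active, dist.entropy(), zeros).sum(dim=-1)
    return samples, log_probs, entropies


torch.manual_seed(0)
N, D, K = 2, 3, 4
logits = torch.randn(N, D, K)
logit_mask = torch.ones(N, D, K, dtype=torch.bool)
logit_mask[..., 0] = False  # illustrative: forbid category 0 everywhere
dimension_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
samples, log_probs, entropies = masked_categorical(logits, logit_mask, dimension_mask)
_, rescored, _ = masked_categorical(logits, logit_mask, dimension_mask, samples)
```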
Masking semantics
Masks allow parameter vectors of different effective sizes to coexist in a fixed-size tensor. Depending on the distribution:
- For categorical: dimension_mask zeros out inactive dimensions; logit_mask can set invalid logits to -inf.
- For processed lognormal: mask (shape (N, D)) indicates the number of active dimensions per batch element; the implementation groups samples by active dimension count and builds smaller distributions for efficiency and numerical stability.
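The grouping strategy for the processed lognormal can be sketched as below. The sketch assumes the active dimensions of each row form a prefix (the first k columns), which the description above does not state, and grouped_lognormal_scores is a hypothetical helper. It relies on the fact that, because L is lower-triangular, its leading k x k block is the Cholesky factor of the marginal covariance of the first k dimensions.

```python
import torch
from torch.distributions import MultivariateNormal


def grouped_lognormal_scores(mu, L, samples, mask):
    """Score correlated-LogNormal samples whose active dims are a prefix of length mask.sum(-1).

    mu: (N, D) log-space means; L: (N, D, D) lower Cholesky factors;
    samples: (N, D) positive values (entries past the active prefix are ignored);
    mask: (N, D) with ones on the active prefix (assumed layout).
    """
    counts = mask.sum(dim=-1).long()
    log_probs = torch.zeros(mu.shape[0], dtype=mu.dtype)
    entropies = torch.zeros_like(log_probs)
    for k in counts.unique().tolist():
        if k == 0:
            continue  # no active parameters: contributes 0
        idx = (counts == k).nonzero(as_tuple=True)[0]
        # One smaller k-dimensional distribution per group of equal active size.
        base = MultivariateNormal(mu[idx, :k], scale_tril=L[idx, :k, :k])
        z = samples[idx, :k].log()
        log_probs[idx] = base.log_prob(z) - z.sum(dim=-1)          # exp Jacobian term
        entropies[idx] = base.entropy() + mu[idx, :k].sum(dim=-1)  # H[exp X] = H[X] + sum mu
    return log_probs, entropies


torch.manual_seed(0)
N, D = 4, 3
mu = torch.randn(N, D)
L = torch.tril(torch.randn(N, D, D), diagonal=-1) + torch.diag_embed(torch.rand(N, D) + 0.5)
mask = torch.tensor([[1, 1, 1], [1, 0, 0], [1, 1, 0], [1, 1, 1]])
samples = torch.rand(N, D) + 0.1  # arbitrary positive values to score
log_probs, entropies = grouped_lognormal_scores(mu, L, samples, mask)
```

Building one distribution per distinct active size avoids padding the covariance with dummy dimensions, whose density terms would otherwise have to be cancelled out.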
Inputs-outputs
The module's forward signature depends on the chosen distribution type, but the returned
values are consistent:
- samples: tensor of sampled parameters (shape (N, D) or (N, 1) for 1D),
- log_probs: tensor of total log-probabilities per batch element (shape (N,)),
- entropies: tensor of total entropies per batch element (shape (N,)).
Notes
- The constructor dynamically assigns self.forward to a distribution-specific implementation. This keeps call sites uniform while allowing different argument sets (e.g., mask vs. logit_mask/dimension_mask).
- For the LogNormal parameterization in 1D/independent modes, the backbone outputs a positive mean/std via softplus and converts them to the underlying Normal parameters:
  sigma = sqrt(log(1 + std^2 / mean^2)),  mu = log(mean) - 0.5 * sigma^2,
  which ensures a valid LogNormal with the requested mean/std.
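The mean/std-to-(mu, sigma) conversion can be checked numerically against torch.distributions.LogNormal, whose mean and stddev should recover the requested values. The helper name lognormal_from_mean_std and the raw backbone values below are illustrative, not taken from the module.

```python
import torch
import torch.nn.functional as F
from torch.distributions import LogNormal


def lognormal_from_mean_std(mean, std):
    """Map a target LogNormal mean/std to the (mu, sigma) of the underlying Normal."""
    sigma = torch.sqrt(torch.log1p(std**2 / mean**2))  # sqrt(log(1 + std^2 / mean^2))
    mu = torch.log(mean) - 0.5 * sigma**2
    return mu, sigma


# Positive mean/std as a backbone might produce them via softplus (raw values are arbitrary).
raw = torch.tensor([[0.3, -1.0], [2.0, 0.5], [-0.7, 1.2]])
mean, std = F.softplus(raw[:, 0]), F.softplus(raw[:, 1])
mu, sigma = lognormal_from_mean_std(mean, std)
dist = LogNormal(mu, sigma)
print(dist.mean - mean, dist.stddev - std)  # differences at float32 round-off level
```

The check works because a LogNormal has mean exp(mu + sigma^2/2), which equals the target mean by construction, and variance (exp(sigma^2) - 1) * mean^2, which equals std^2.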
__init__(distribution, backbone_attributes, device='cpu')
Construct a parameter generator for a specified distribution family.
| PARAMETER | DESCRIPTION |
|---|---|
| distribution | dict. Distribution specification. Must include: … Optional keys for … |
| backbone_attributes | dict. FFNN configuration with keys: … |
| device | str or torch.device, default "cpu". Device where parameters, backbone, and intermediate tensors are allocated. |
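The dynamic forward assignment performed by the constructor can be sketched as follows. DispatchingGenerator, its handler names, and the backbone_attributes keys "input_size"/"output_size" are hypothetical (the real FFNN keys are not listed here), and the handler bodies are placeholders; only the dispatch on distribution["type"] mirrors the description.

```python
import torch
import torch.nn as nn


class DispatchingGenerator(nn.Module):
    """Hypothetical sketch: bind self.forward to a type-specific method in __init__."""

    def __init__(self, distribution: dict, backbone_attributes: dict, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        # "input_size"/"output_size" are hypothetical keys used only for this sketch.
        self.backbone = nn.Linear(
            backbone_attributes["input_size"], backbone_attributes["output_size"]
        ).to(self.device)
        handlers = {
            "categorical": self._forward_categorical,
            "lognormal_processed": self._forward_masked,
        }
        kind = distribution["type"]
        if kind not in handlers:
            raise ValueError(f"unsupported distribution type: {kind!r}")
        # nn.Module.__call__ invokes self.forward, and an instance attribute
        # shadows the class-level forward, so call sites stay gen(x, ...).
        self.forward = handlers[kind]

    def _forward_categorical(self, x, logit_mask=None, dimension_mask=None):
        return "categorical", self.backbone(x.to(self.device))  # placeholder body

    def _forward_masked(self, x, mask=None):
        return "lognormal_processed", self.backbone(x.to(self.device))  # placeholder body


gen = DispatchingGenerator({"type": "categorical"}, {"input_size": 5, "output_size": 3})
kind, out = gen(torch.zeros(2, 5), dimension_mask=None)
```

Binding forward per instance lets each handler declare its own keyword arguments (mask vs. logit_mask/dimension_mask) behind one call syntax; the trade-off is that the effective signature is only known after construction.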