categorical
RL4CRN.distributions.categorical
Distribution utilities for joint categorical variables.
This module provides:
- `_mixed_radix`, a helper that builds mixed-radix multipliers for encoding multi-dimensional categorical indices into a single flattened index.
- `MultiVariateCategorical`, a joint categorical distribution over \(M\) discrete variables, represented internally as a single `torch.distributions.Categorical` over \(\prod_{m=0}^{M-1} K_m\) outcomes.
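The joint outcome corresponding to per-dimension indices \(\mathbf{i} = (i_0, \dots, i_{M-1})\) is encoded into a flat index:
\[
k(\mathbf{i}) = \sum_{m=0}^{M-1} i_m \, r_m,
\]
where \(r_m\) are the mixed-radix multipliers:
\[
r_m = \prod_{j=m+1}^{M-1} K_j \quad (\text{so } r_{M-1} = 1).
\]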
Decoding reverses this mapping using integer division and modulo operations.
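The encoding and decoding can be illustrated in plain Python. This is a minimal sketch, not the module's `_mixed_radix` implementation: it assumes a row-major convention (the last dimension varies fastest), and the helper names `mixed_radix`, `encode` and `decode` are hypothetical.

```python
from math import prod

def mixed_radix(arities):
    """Multipliers r_m = prod(K_j for j > m) (row-major assumption)."""
    return [prod(arities[m + 1:]) for m in range(len(arities))]

def encode(idx, arities):
    """Map per-dimension indices (i_0, ..., i_{M-1}) to one flat index."""
    radix = mixed_radix(arities)
    return sum(i * r for i, r in zip(idx, radix))

def decode(flat, arities):
    """Invert encode() using integer division and modulo."""
    radix = mixed_radix(arities)
    return tuple((flat // r) % k for r, k in zip(radix, arities))

arities = [2, 3, 4]
print(mixed_radix(arities))        # [12, 4, 1]
print(encode((1, 2, 3), arities))  # 23
print(decode(23, arities))         # (1, 2, 3)
```

Because `decode(encode(i))` recovers `i` for every index tuple, the flat categorical and the per-dimension view carry exactly the same information.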
Notes
The distribution can operate in two modes:
- **Index mode** via `arities=[K1, ..., KM]`: samples are integer indices in `0..K_m-1` for each dimension.
- **Value mode** via `values=[v1, ..., vM]`: samples are explicit numeric categories, where each `vm` is a strictly increasing 1D tensor of length `K_m`. Sampling returns the corresponding values rather than indices.
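In value mode, sampled indices have to be mapped to category values, and `log_prob` has to map values back to indices. The sketch below shows one plausible way to do the reverse lookup, a binary search with Python's `bisect` on plain lists. It assumes (without confirmation from the source) that this is the reason the value sets must be strictly increasing; `index_to_value` and `value_to_index` are hypothetical names, not the module's API.

```python
from bisect import bisect_left

def index_to_value(indices, values):
    """Map per-dimension indices to the explicit category values."""
    return tuple(v[i] for i, v in zip(indices, values))

def value_to_index(sample, values):
    """Map explicit values back to indices; needs each list strictly increasing."""
    out = []
    for x, v in zip(sample, values):
        i = bisect_left(v, x)  # binary search: valid only on sorted lists
        if i == len(v) or v[i] != x:
            raise ValueError(f"{x!r} is not in the category set {v!r}")
        out.append(i)
    return tuple(out)

values = [[10, 20], [0.1, 0.2, 0.3]]
print(index_to_value((1, 2), values))     # (20, 0.3)
print(value_to_index((20, 0.3), values))  # (1, 2)
```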
MultiVariateCategorical
Bases: Distribution
Joint categorical distribution over multiple discrete variables.
This distribution represents a joint categorical distribution over \(M\) discrete variables by flattening the joint support of size \(\prod_{m=0}^{M-1} K_m\) into a single categorical random variable.
Two ways to define the per-dimension categories:
1. **Index mode**: pass `arities=[K1, ..., KM]`. Each variable takes
values in `{0, ..., K_m-1}` and samples are returned as integer
indices of shape `(..., M)`.
2. **Value mode**: pass `values=[v1, ..., vM]`, where each `vm` is a
strictly increasing 1D tensor of length `K_m`. Samples are returned
as values from these tensors (shape `(..., M)`).
Shapes
- Batch shape is inherited from `logits`/`probs` (everything except the last dimension).
- Event shape is `(M,)`.
Validation
- `probs` must be a simplex along the last dimension.
- `logits` can be any real vector along the last dimension.
- `support` is marked as `constraints.dependent` because exact support checking is non-trivial when using explicit `values`. Shape validation still applies.
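The two parameter constraints can be sketched in plain Python: `probs` must be non-negative and sum to one, while any real `logits` vector becomes a valid probability vector through a softmax. The tolerance `atol` and the helper names `is_simplex` and `logits_to_probs` are assumptions for illustration; PyTorch performs its own constraint checks when `validate_args` is enabled.

```python
import math

def is_simplex(probs, atol=1e-6):
    """True if probs is a valid probability vector (non-negative, sums to 1)."""
    return all(p >= 0 for p in probs) and abs(sum(probs) - 1.0) <= atol

def logits_to_probs(logits):
    """Softmax: any real vector maps to a point on the simplex."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

print(is_simplex([0.5, 0.25, 0.25]))                  # True
print(is_simplex([0.5, 0.6]))                         # False
print(is_simplex(logits_to_probs([3.0, -1.0, 0.0])))  # True
```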
Examples:
Index mode:
>>> dist = MultiVariateCategorical(arities=[2, 3], logits=torch.zeros(6))
>>> x = dist.sample((5,)) # shape (5, 2), entries in [0..1] and [0..2]
Value mode:
>>> vals0 = torch.tensor([10, 20])
>>> vals1 = torch.tensor([0.1, 0.2, 0.3])
>>> dist = MultiVariateCategorical(values=[vals0, vals1], probs=torch.ones(6)/6)
>>> x = dist.sample() # shape (2,), values from vals0 and vals1
event_shape
property
Shape of a single draw from the distribution (always (M,)).
batch_shape
property
Batch shape of the distribution (inherited from base categorical).
dtype
property
Data type of samples (values dtype in value mode, else torch.long).
probs
property
Probabilities of the flattened joint distribution (shape `(..., Ktot)`, where `Ktot` \(= \prod_m K_m\)).
logits
property
Logits of the flattened joint distribution (shape (..., Ktot)).
support
property
Support constraint.
Marked as constraints.dependent because exact element-wise support checks
are non-trivial when each dimension can have explicit value sets.
__init__(*, arities=None, values=None, logits=None, probs=None, validate_args=None)
Create a multivariate joint categorical distribution.
| PARAMETER | DESCRIPTION |
|---|---|
| `arities` | 1D sequence/tensor of per-dimension arities `[K1, ..., KM]` (index mode). DEFAULT: `None` |
| `values` | Optional list of 1D tensors specifying explicit categories per dimension (value mode). Each tensor must be non-empty and strictly increasing. Mutually exclusive with `arities`. DEFAULT: `None` |
| `logits` | Logits for the flattened joint categorical distribution. Must have last dimension size `Ktot` \(= \prod_m K_m\). DEFAULT: `None` |
| `probs` | Probabilities for the flattened joint categorical distribution. Must have last dimension size `Ktot` \(= \prod_m K_m\). DEFAULT: `None` |
| `validate_args` | Passed to … DEFAULT: `None` |
| RAISES | DESCRIPTION |
|---|---|
| `AssertionError` | If not exactly one of … |
| `ValueError` | If … |
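The constructor checks described above can be sketched in plain Python: exactly one of `arities`/`values`, exactly one of `logits`/`probs`, non-empty strictly increasing value sets, and a flat parameter vector of length `Ktot`. The function name `check_args`, the use of `ValueError` for every failure, and the restriction to an unbatched 1D parameter list are all assumptions for illustration; the real constructor's exact conditions and exception types may differ.

```python
from math import prod

def check_args(arities=None, values=None, logits=None, probs=None):
    """Validate constructor-style arguments and return the per-dimension arities."""
    if (arities is None) == (values is None):
        raise ValueError("pass exactly one of `arities` or `values`")
    if (logits is None) == (probs is None):
        raise ValueError("pass exactly one of `logits` or `probs`")
    if values is not None:
        for v in values:
            if len(v) == 0:
                raise ValueError("each value set must be non-empty")
            if any(a >= b for a, b in zip(v, v[1:])):
                raise ValueError("each value set must be strictly increasing")
        arities = [len(v) for v in values]
    params = logits if logits is not None else probs
    k_tot = prod(arities)
    if len(params) != k_tot:  # unbatched case: the whole list is the last dim
        raise ValueError(f"expected last dimension {k_tot}, got {len(params)}")
    return list(arities)

print(check_args(arities=[2, 3], logits=[0.0] * 6))                       # [2, 3]
print(check_args(values=[[10, 20], [0.1, 0.2, 0.3]], probs=[1 / 6] * 6))  # [2, 3]
```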
sample(sample_shape=torch.Size())
Draw samples.
| PARAMETER | DESCRIPTION |
|---|---|
| `sample_shape` | Optional leading sample shape. DEFAULT: `torch.Size()` |

| RETURNS | DESCRIPTION |
|---|---|
|  | Samples of shape `sample_shape + batch_shape + (M,)`. |
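Conceptually, sampling draws one flat index from the joint categorical and decodes it per dimension. The plain-Python sketch below uses `random.choices` in place of `torch.distributions.Categorical`, assumes the row-major encoding, and omits batch dimensions; `sample_joint` is a hypothetical name. In value mode each decoded index tuple would additionally be mapped to its category values.

```python
import random
from math import prod

def sample_joint(probs, arities, n, rng=random):
    """Draw n joint samples as tuples of per-dimension indices."""
    k_tot = prod(arities)
    flats = rng.choices(range(k_tot), weights=probs, k=n)  # flat categorical draws
    radix = [prod(arities[m + 1:]) for m in range(len(arities))]
    return [tuple((f // r) % k for r, k in zip(radix, arities)) for f in flats]

rng = random.Random(0)
draws = sample_joint([1 / 6] * 6, [2, 3], n=5, rng=rng)
print(len(draws))  # 5, the analogue of sample((5,)) having shape (5, 2)
print(all(0 <= a < 2 and 0 <= b < 3 for a, b in draws))  # True
```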
log_prob(value)
Compute log-probability of a batch of samples.
| PARAMETER | DESCRIPTION |
|---|---|
| `value` | Samples of shape `(..., M)` (indices in index mode, category values in value mode). |

| RETURNS | DESCRIPTION |
|---|---|
|  | Tensor of log-probabilities with shape `value.shape[:-1]`. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | In value mode, if any entry is not contained in the corresponding category set. |
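The log-probability of a joint sample is the log of the flattened probability at its encoded index. A minimal plain-Python sketch under the same row-major assumption; `log_prob_joint` is a hypothetical helper, and in value mode the values would first be mapped back to indices (raising `ValueError` for unknown entries).

```python
import math
from math import prod

def log_prob_joint(sample, probs, arities):
    """log p(i_0, ..., i_{M-1}) = log probs[flat_index]."""
    radix = [prod(arities[m + 1:]) for m in range(len(arities))]
    flat = sum(i * r for i, r in zip(sample, radix))
    return math.log(probs[flat])

probs = [0.1, 0.2, 0.3, 0.1, 0.2, 0.1]  # K0 = 2, K1 = 3
print(log_prob_joint((1, 0), probs, [2, 3]))  # equals math.log(0.1): (1, 0) -> flat 3
```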
joint_table()
Return the joint probability table reshaped to per-dimension axes.
| RETURNS | DESCRIPTION |
|---|---|
|  | Tensor of shape `(..., K1, ..., KM)` containing the joint probability table. |
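The reshape can be mimicked in plain Python by nesting the flat probability list one dimension at a time, so that under the row-major assumption `table[i0][i1]...` equals `probs[flat_index]`. `joint_table_py` is a hypothetical helper; the real method returns a tensor. Summing over the remaining axes then gives per-dimension marginals.

```python
from math import prod

def joint_table_py(probs, arities):
    """Reshape a flat list of length prod(arities) into nested lists (row-major)."""
    if len(probs) != prod(arities):
        raise ValueError("length of probs must equal prod(arities)")
    table = list(probs)
    # Group from the last (fastest-varying) dimension outwards.
    for k in reversed(arities[1:]):
        table = [table[j:j + k] for j in range(0, len(table), k)]
    return table

probs = [0.1, 0.2, 0.3, 0.1, 0.2, 0.1]
table = joint_table_py(probs, [2, 3])
print(table)                        # [[0.1, 0.2, 0.3], [0.1, 0.2, 0.1]]
print([sum(row) for row in table])  # marginal distribution of dimension 0
```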