categorical

RL4CRN.distributions.categorical

Distribution utilities for joint categorical variables.

This module provides:

  • _mixed_radix, a helper to build mixed-radix multipliers for encoding multi-dimensional categorical indices into a single flattened index.
  • MultiVariateCategorical, a joint categorical distribution over \(M\) discrete variables represented internally as a single torch.distributions.Categorical over \(\prod_i K_i\) outcomes.

The joint outcome corresponding to per-dimension indices \(\mathbf{i} = (i_0, \dots, i_{M-1})\) is encoded into a flat index:

\[z = \sum_{m=0}^{M-1} i_m \cdot r_m,\]

where \(r_m\) are the mixed-radix multipliers:

\[ r_m = \prod_{j=m+1}^{M-1} K_j.\]

Decoding reverses this mapping using integer division and modulo operations.
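
As an illustration of the encoding and decoding formulas above, here is a minimal pure-Python sketch. The helper names `mixed_radix`, `encode`, and `decode` are hypothetical stand-ins chosen for this example; they are not the module's `_mixed_radix` helper, whose exact signature is not shown here.

```python
from math import prod

def mixed_radix(arities):
    """Return multipliers r_m = prod(K_j for j > m), e.g. [3, 1] for arities [2, 3]."""
    return [prod(arities[m + 1:]) for m in range(len(arities))]

def encode(indices, arities):
    """Flatten per-dimension indices (i_0, ..., i_{M-1}) into z = sum_m i_m * r_m."""
    return sum(i * r for i, r in zip(indices, mixed_radix(arities)))

def decode(z, arities):
    """Recover per-dimension indices via i_m = (z // r_m) % K_m."""
    return [(z // r) % k for r, k in zip(mixed_radix(arities), arities)]

arities = [2, 3, 4]
z = encode([1, 2, 3], arities)   # 1*12 + 2*4 + 3*1 = 23
assert decode(z, arities) == [1, 2, 3]
```

With this convention the last dimension varies fastest, which matches row-major (C-order) flattening of a table with axes of sizes \(K_0, \dots, K_{M-1}\).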

Notes

The distribution can operate in two modes:

  1. Index mode via arities=[K1,...,KM]: samples are integer indices in 0..Ki-1 for each dimension.
  2. Value mode via values=[v1,...,vM]: samples are explicit numeric categories, where each vm is a strictly increasing 1D tensor of length Km. Sampling returns the corresponding values rather than indices.

MultiVariateCategorical

Bases: Distribution

Joint categorical distribution over multiple discrete variables.

This distribution represents a joint categorical distribution over \(M\) discrete variables by flattening the joint support of size \(\prod_{m=0}^{M-1} K_m\) into a single categorical random variable.

Two ways to define the per-dimension categories:

1. **Index mode**: pass `arities=[K1, ..., KM]`. Each variable takes
   values in `{0, ..., K_m-1}` and samples are returned as integer
   indices of shape `(..., M)`.
2. **Value mode**: pass `values=[v1, ..., vM]`, where each `vm` is a
   strictly increasing 1D tensor of length `K_m`. Samples are returned
   as values from these tensors (shape `(..., M)`).
Shapes
  • Batch shape is inherited from logits/probs (everything except last dim).
  • Event shape is (M,).
Validation
  • probs must be a simplex along the last dimension.
  • logits can be any real vector along the last dimension.
  • support is marked as constraints.dependent because exact support checking is non-trivial when using explicit values. Shape validation still applies.

Examples:

Index mode:

>>> dist = MultiVariateCategorical(arities=[2, 3], logits=torch.zeros(6))
>>> x = dist.sample((5,))  # shape (5, 2), entries in [0..1] and [0..2]

Value mode:

>>> vals0 = torch.tensor([10, 20])
>>> vals1 = torch.tensor([0.1, 0.2, 0.3])
>>> dist = MultiVariateCategorical(values=[vals0, vals1], probs=torch.ones(6)/6)
>>> x = dist.sample()  # shape (2,), values from vals0 and vals1

event_shape property

Shape of a single draw from the distribution (always (M,)).

batch_shape property

Batch shape of the distribution (inherited from base categorical).

dtype property

Data type of samples (values dtype in value mode, else torch.long).

probs property

Probabilities of the flattened joint distribution (shape (..., Ktot), where Ktot = K1 * ... * KM is the joint cardinality).

logits property

Logits of the flattened joint distribution (shape (..., Ktot)).

support property

Support constraint.

Marked as constraints.dependent because exact element-wise support checks are non-trivial when each dimension can have explicit value sets.

__init__(*, arities=None, values=None, logits=None, probs=None, validate_args=None)

Create a multivariate joint categorical distribution.

PARAMETER DESCRIPTION
arities

1D sequence/tensor (K1,...,KM) specifying the number of categories per dimension. Used in index mode. Mutually exclusive with values.

DEFAULT: None

values

Optional list of 1D tensors specifying explicit categories per dimension. Each tensor must be non-empty and strictly increasing. Mutually exclusive with arities.

DEFAULT: None

logits

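Logits for the flattened joint categorical distribution. The last dimension must equal the joint cardinality K1 * ... * KM, that is, prod(arities) in index mode or the product of the lengths of the values tensors in value mode. Mutually exclusive with probs.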

DEFAULT: None

probs

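Probabilities for the flattened joint categorical distribution. The last dimension must equal the joint cardinality K1 * ... * KM (prod(arities) in index mode, or the product of the lengths of the values tensors in value mode) and must be a simplex. Mutually exclusive with logits.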

DEFAULT: None

validate_args

Passed to torch.distributions.Distribution to enable/disable argument validation.

DEFAULT: None

RAISES DESCRIPTION
AssertionError

If not exactly one of logits/probs is provided, or if arities/values are missing/invalid, or if the last dimension of logits/probs does not match the joint cardinality.

ValueError

If values tensors are not 1D, empty, or not strictly increasing.

sample(sample_shape=torch.Size())

Draw samples.

PARAMETER DESCRIPTION
sample_shape

Optional leading sample shape.

DEFAULT: torch.Size()

RETURNS DESCRIPTION

Samples of shape sample_shape + batch_shape + (M,).

  • In index mode (arities provided): integer indices in [0, K_m-1].
  • In value mode (values provided): values from the provided category tensors per dimension.
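
The two return modes can be pictured with the following pure-Python sketch, which draws a flat index, decodes it, and optionally maps indices to values. It is only an illustration under stated assumptions: `sample_joint` is a hypothetical helper, plain lists stand in for tensors, and the class's actual sampling code may differ.

```python
import random
from math import prod

def sample_joint(probs, arities, values=None, rng=random):
    """Draw one joint outcome from a flat probability vector (hypothetical helper)."""
    assert len(probs) == prod(arities)
    # Draw a flat index z from the categorical over prod(K_m) outcomes.
    z = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    # Decode z into per-dimension indices, last dimension varying fastest.
    indices = []
    for k in reversed(arities):
        indices.append(z % k)
        z //= k
    indices.reverse()
    if values is None:
        return indices                                # index mode
    return [v[i] for v, i in zip(values, indices)]    # value mode

vals = [[10, 20], [0.1, 0.2, 0.3]]
draw = sample_joint([1 / 6] * 6, [2, 3], values=vals)
```

Here `draw` holds one entry from each list in `vals`; omitting `values` returns the raw indices instead.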

log_prob(value)

Compute log-probability of a batch of samples.

PARAMETER DESCRIPTION
value

Samples of shape (..., M). In index mode these should be integer indices; in value mode they should match the explicit category values.

RETURNS DESCRIPTION

Tensor of log-probabilities with shape value.shape[:-1].

RAISES DESCRIPTION
ValueError

In value mode, if any entry is not contained in the corresponding category set.
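
One plausible reason for requiring strictly increasing category tensors is that membership can then be checked with a binary search. The sketch below illustrates that idea with the standard-library `bisect` module; `value_to_index` is a hypothetical helper, and whether the class uses this strategy internally is an assumption.

```python
from bisect import bisect_left

def value_to_index(value, categories):
    """Return the position of value in a strictly increasing list, else raise."""
    pos = bisect_left(categories, value)
    # bisect_left gives the insertion point; the value is present only if
    # that slot exists and holds exactly the same value.
    if pos == len(categories) or categories[pos] != value:
        raise ValueError(f"{value!r} is not one of the categories {categories!r}")
    return pos

idx = value_to_index(0.2, [0.1, 0.2, 0.3])
```

The comparison is exact, so floating-point values that are only approximately equal to a category are rejected in this sketch.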

joint_table()

Return the joint probability table reshaped to per-dimension axes.

RETURNS DESCRIPTION

Tensor of shape batch_shape + (K1, ..., KM) representing the full joint probability table.
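
To show what reshaping the flat probabilities into per-dimension axes means, here is a pure-Python sketch for a single distribution with no batch dimensions. The function below is a hypothetical stand-in that nests plain lists; with tensors, the analogous operation would presumably be a reshape of the last dimension to `(K1, ..., KM)`.

```python
from math import prod

def joint_table(probs, arities):
    """Nest a flat probability list into axes of sizes arities (row-major)."""
    assert len(probs) == prod(arities)
    if len(arities) == 1:
        return list(probs)
    stride = prod(arities[1:])  # number of flat entries per slice of the first axis
    return [joint_table(probs[i * stride:(i + 1) * stride], arities[1:])
            for i in range(arities[0])]

probs = [0.1, 0.2, 0.1, 0.3, 0.2, 0.1]
table = joint_table(probs, [2, 3])       # table[i0][i1] == probs[i0 * 3 + i1]
marginal0 = [sum(row) for row in table]  # marginal over the first variable
```

Summing the table over all axes but one yields that variable's marginal distribution, which is one common use of the reshaped form.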