stochastic
RL4CRN.rewards.stochastic
Stochastic rewards (SSA).
This module defines reward / loss functions that evaluate IOCRN controllers under
intrinsic stochasticity by running Stochastic Simulation Algorithm (SSA) rollouts
via IOCRN.transient_response_SSA. The rewards are computed primarily from the
mean output trajectories, with optional robustness terms that penalize variability
(e.g., steady-state coefficient of variation), and they populate crn.last_task_info
with metadata for downstream logging and analysis.
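The rollout-and-average pattern described above can be illustrated with a self-contained toy. The sketch below does not call IOCRN.transient_response_SSA. It assumes a hypothetical birth-death network (0 → X at rate k, X → 0 at rate g·x), simulates it with Gillespie's direct method, and returns the per-time mean and standard deviation across trajectories. Those are the two statistics the rewards in this module consume. The function names are illustrative only.

```python
import math
import random


def ssa_birth_death(k, g, x0, time_grid, rng):
    """Gillespie direct method for 0 -> X (propensity k) and X -> 0 (propensity g*x).

    Returns the copy number of X observed at each time in time_grid (ascending).
    """
    t, x = 0.0, x0
    samples = []
    for t_sample in time_grid:
        # Fire reactions until the next one would land after t_sample.
        while True:
            a_birth, a_death = k, g * x
            a_total = a_birth + a_death
            if a_total == 0.0:
                t = t_sample  # nothing can fire; the state is frozen
                break
            tau = rng.expovariate(a_total)
            if t + tau > t_sample:
                # Exponential waiting times are memoryless, so restarting
                # the clock at t_sample (propensities unchanged) is exact.
                t = t_sample
                break
            t += tau
            if rng.random() * a_total < a_birth:
                x += 1
            else:
                x -= 1
        samples.append(x)
    return samples


def ensemble_mean_std(k, g, x0, time_grid, n_trajectories=100, seed=0):
    """Run n_trajectories SSA rollouts and return the per-time mean and std of X."""
    rng = random.Random(seed)
    runs = [ssa_birth_death(k, g, x0, time_grid, rng) for _ in range(n_trajectories)]
    means, stds = [], []
    for j in range(len(time_grid)):
        column = [run[j] for run in runs]
        m = sum(column) / n_trajectories
        var = sum((v - m) ** 2 for v in column) / n_trajectories  # population variance
        means.append(m)
        stds.append(math.sqrt(var))
    return means, stds
```

For k = 10 and g = 1 the stationary distribution of this network is Poisson with mean 10, so late-time means should sit near 10 and standard deviations near √10.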
dynamic_tracking_error_SSA(crn, u_list, x0_list, time_horizon, r_list, w, n_trajectories=100, max_threads=10000, norm=1, relative=False, LARGE_NUMBER=10000.0, LARGE_PENALTY=10000.0)
Compute a dynamic tracking cost using stochastic simulation (SSA).
This function evaluates tracking performance under intrinsic noise by running
multiple SSA trajectories per (input, initial-condition) scenario and comparing
the mean output trajectory against the reference r_list using
performance_metric.
The cost is computed on the mean trajectories only (variance is ignored here). If the SSA simulator reports divergence, a large constant penalty is returned.
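The mean-trajectory cost can be sketched as follows. The exact definition of performance_metric is not given on this page, so the function below assumes a weighted L^p error per scenario, (Σ_t w_t |ȳ_t − r_t|^p)^(1/p), averaged over scenarios, with the reference r given as a sequence aligned with the time grid. The name mean_tracking_cost and its arguments are hypothetical; only the divergence short-circuit mirrors the behavior documented above.

```python
def mean_tracking_cost(y_mean_list, r_list, w, norm=1, relative=False,
                       has_diverged=False, large_penalty=1e4, eps=1e-12):
    """Weighted L^norm error between mean outputs and references, averaged over scenarios.

    Hypothetical stand-in for performance_metric: each y_mean and r is a
    sequence aligned with the time grid, and w holds one weight per time point.
    """
    if has_diverged:
        return large_penalty  # constant penalty instead of a meaningless error
    per_scenario = []
    for y_mean, r in zip(y_mean_list, r_list):
        total = 0.0
        for y_t, r_t, w_t in zip(y_mean, r, w):
            err = abs(y_t - r_t)
            if relative:
                err /= max(abs(r_t), eps)  # guard against a zero reference
            total += w_t * err ** norm
        per_scenario.append(total ** (1.0 / norm))
    return sum(per_scenario) / len(per_scenario)
```

Only the mean trajectory enters the cost, so two controllers with identical means but very different spreads score the same here; the robustness-aware loss below adds a variability term for that reason.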
| PARAMETER | DESCRIPTION |
|---|---|
| crn (IOCRN) | IOCRN-like object implementing transient_response_SSA. |
| u_list (list[np.ndarray]) | List of constant input vectors, each of shape |
| x0_list (list[np.ndarray]) | List of initial state vectors, each of shape |
| time_horizon (np.ndarray) | 1D time grid (or a simulator-specific time specification) over which the SSA trajectories are sampled. |
| r_list (list[np.ndarray]) | List of reference targets per scenario, in the same format expected by performance_metric. |
| w (np.ndarray) | Weights for the tracking metric. Typically shape |
| n_trajectories (int, default=100) | Number of SSA trajectories simulated per scenario. |
| max_threads (int, default=10000) | Upper bound on GPU threads / parallelism for SSA (passed through to the CRN). |
| norm (int, default=1) | Norm used by performance_metric. |
| relative (bool, default=False) | If True, compute a relative error (as supported by performance_metric). |
| LARGE_NUMBER (float, default=1e4) | Maximum value / divergence threshold passed to the SSA simulator. |
| LARGE_PENALTY (float, default=1e4) | Cost returned when the simulator indicates divergence. |
| RETURNS | DESCRIPTION |
|---|---|
| performance (float) | Scalar tracking cost (lower is better). |
| last_task_info (dict) | Updated crn.last_task_info metadata. |
robust_tracking_loss_SSA(crn, u_list, x0_list, time_horizon, r_list, w, n_trajectories=100, max_threads=10000, norm=2, relative=False, LARGE_NUMBER=10000.0, LARGE_PENALTY=10000.0, lambda_std=0.5, rpa_weight=1.0, cv_weight=1.0)
Compute a robustness-aware tracking loss under SSA.
This loss combines:
- an accuracy term (tracking error on the mean trajectory), and
- a precision term (steady-state coefficient-of-variation penalty).
The accuracy term is computed via performance_metric(r_list, y_mean_list, w, ...).
The precision term uses the coefficient of variation (CV = std / |mean|) computed
from the SSA output standard deviation and mean in time regions where the weight
vector w is positive (interpreted here as the “steady-state” window).
The final loss is:
loss = rpa_weight * base_error + cv_weight * (lambda_std * mean_cv)
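The precision term and the final combination can be written out directly from the formula above. In this sketch the eps guard on a near-zero mean and the 0.0 fallback for an empty steady-state window are assumptions, not documented behavior, and the function names are hypothetical. It handles one scenario; averaging mean_cv across scenarios is left to the caller.

```python
def steady_state_cv(y_mean, y_std, w, eps=1e-12):
    """Average CV = std / |mean| over time points where the weight w is positive."""
    cvs = [s / max(abs(m), eps) for m, s, w_t in zip(y_mean, y_std, w) if w_t > 0]
    return sum(cvs) / len(cvs) if cvs else 0.0  # assumed fallback for an empty window


def robust_loss(base_error, mean_cv, rpa_weight=1.0, cv_weight=1.0, lambda_std=0.5):
    """loss = rpa_weight * base_error + cv_weight * (lambda_std * mean_cv)."""
    return rpa_weight * base_error + cv_weight * (lambda_std * mean_cv)
```

For example, a mean of 10 with standard deviations 1 and 3 inside the window gives CVs of 0.1 and 0.3, so mean_cv = 0.2; with base_error = 2 and the default weights the loss is 2 + 0.5 · 0.2 = 2.1.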
Notes
- This implementation assumes a single-output layout that is sometimes produced as (B, 1, T); if so, it squeezes the singleton dimension to (B, T).
- If crn.last_task_info['has_diverged'] is True, the loss components are overridden with a large penalty / safe defaults.
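Both notes can be sketched in a few lines of NumPy. The squeeze only fires for the (B, 1, T) layout and leaves multi-output arrays untouched. The specific "safe defaults" are not stated here, so the override below assumes base_error becomes the penalty and mean_cv becomes 0.0; both helper names are hypothetical.

```python
import numpy as np


def squeeze_single_output(y):
    """Collapse a (B, 1, T) single-output array to (B, T); leave other shapes alone."""
    y = np.asarray(y)
    if y.ndim == 3 and y.shape[1] == 1:
        y = np.squeeze(y, axis=1)
    return y


def apply_divergence_override(base_error, mean_cv, task_info, large_penalty=1e4):
    """Replace loss components after divergence (assumed defaults: penalty, 0.0)."""
    if task_info.get("has_diverged", False):
        return large_penalty, 0.0
    return base_error, mean_cv
```

Passing axis=1 explicitly avoids accidentally dropping a batch or time dimension that happens to have length 1.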
| PARAMETER | DESCRIPTION |
|---|---|
| crn (IOCRN) | IOCRN-like object implementing transient_response_SSA. |
| u_list (list[np.ndarray]) | List of constant input vectors, each of shape |
| x0_list (list[np.ndarray]) | List of initial state vectors, each of shape |
| time_horizon (np.ndarray) | Time grid passed to SSA. |
| r_list (list[np.ndarray]) | Reference targets per scenario (format expected by performance_metric). |
| w (np.ndarray) | Weights used for the tracking error and to identify the steady-state window. For the CV penalty, time points where w is positive are treated as the steady-state window. |
| n_trajectories (int, default=100) | Number of SSA trajectories per scenario. |
| max_threads (int, default=10000) | Upper bound on GPU threads / parallelism for SSA. |
| norm (int, default=2) | Norm used by performance_metric. |
| relative (bool, default=False) | If True, compute a relative tracking error (as supported by performance_metric). |
| LARGE_NUMBER (float, default=1e4) | Maximum value / divergence threshold passed to the SSA simulator. |
| LARGE_PENALTY (float, default=1e4) | Penalty used when divergence is detected. |
| lambda_std (float, default=0.5) | Scaling factor applied to the mean CV penalty. |
| rpa_weight (float, default=1.0) | Weight multiplying the accuracy (mean-trajectory tracking) term. |
| cv_weight (float, default=1.0) | Weight multiplying the precision (CV) term. |
| RETURNS | DESCRIPTION |
|---|---|
| performance (float) | Scalar robustness-aware loss (lower is better). |
| last_task_info (dict) | Updated crn.last_task_info metadata. |