input_interface
RL4CRN.utils.input_interface
User-facing input interface utilities for GenAI-Net / RL4CRN tutorials.
This module provides:

- lightweight configuration objects with sensible defaults
- a configurator to apply presets and overrides
- a session builder that wires together task/template/library/env/interfaces/policy/agent
- a trainer that supports chunked training, early stopping (Ctrl+C), and save/load checkpoints
The goal is to make tutorial notebooks trivial to run, while keeping all advanced knobs discoverable via config inspection.
TaskSpec
dataclass
Fully materialized task description used by environments.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `template_crn` | Compiled IOCRN template. |
| `library_components` | Tuple `(library, M, K, masks)`. |
| `species_labels` | Species labels used by the task. |
| `kind` | Task kind (e.g., `"logic"`, `"tracking"`, `"oscillator_mean"`, ...). |
| `t_f` | Final simulation time. |
| `n_t` | Number of time points. |
| `time_horizon` | 1D array of time points (float32). |
| `n_inputs` | Number of input channels. |
| `u_values` | Values for grid tasks (tracking/oscillator/SSA). |
| `dose_range` | `(u_min, u_max, n)` for `"dose_response"`. |
| `u_spec` | Optional input generation spec. |
| `u_list` | List of input vectors (each shape `(p,)`, float32). |
| `ic_spec` | IC specification used to build the IC object. |
| `ic` | RL4CRN IC object. |
| `weights_spec` | Weight spec used to build the weight matrix (when applicable). |
| `weights` | Weight matrix (when applicable). |
| `target` | Target spec for tracking/SSA tasks. |
| `logic_fn` | Boolean logic function for `"logic"`. |
| `target_fn` | Target function for dose response. |
| `osc_w` | Oscillation error weights. |
| `t0` | Oscillation error start time. |
| `n_trajectories` | Number of SSA trajectories. |
| `max_threads` | Maximum number of SSA threads. |
| `cv_weight` | Robust SSA CV weight. |
| `rpa_weight` | Robust SSA RPA weight. |
| `relative` | Whether to use relative error in SSA rewards. |
| `norm` | Norm used in tracking losses. |
| `LARGE_NUMBER` | Large penalty scalar used by deterministic rewards. |
| `LARGE_PENALTY` | Large penalty scalar used by SSA rewards (when applicable). |
| `compute_reward` | Reward callable built from this TaskSpec. |
| `params` | Task-kind-specific parameters (forward-compatible extension point). |
SolverCfg
dataclass
Solver configuration.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `algorithm` | Solver name (e.g., `"CVODE"` or `"LSODA"`). |
| `rtol` | Relative tolerance. |
| `atol` | Absolute tolerance. |
TrainCfg
dataclass
Training configuration.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `epochs` | Total number of epochs (you may run in chunks). |
| `max_added_reactions` | Episode length: number of reaction-addition steps. |
| `render_every` | Print progress every N epochs (0 disables). |
| `hall_of_fame_size` | Hall-of-fame capacity in ParallelEnvironments. |
| `batch_multiplier` | Batch size = `batch_multiplier * num_cpus` (if `batch_size` is None). |
| `seed` | Random seed for reproducibility. |
| `n_cpus` | CPU count to use. If None, uses `os.cpu_count()`. |
| `batch_size` | If provided, overrides automatic batch sizing. |
PolicyCfg
dataclass
Policy network configuration.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `width` | Hidden size for encoder/heads. |
| `depth` | Number of layers for encoder/heads. |
| `deep_layer_size` | Size of the deep layer block (policy-dependent). |
| `continuous_distribution` | Dict describing the continuous parameter distribution. |
| `entropy_weights_per_head` | Entropy coefficients per head. |
| `ordering_enabled` | If True, uses the ordered reaction-addition policy. |
| `constraint_strength` | Constraint strength for the ordered policy. |
| `zero_reaction_idx` | If set, this reaction index is treated as a "no-op" action (allowing multiple re-samples). |
| `stop_flag` | Internal flag indicating whether calling a "zero reaction" is a stopping condition (instead of a no-op). |
AgentCfg
dataclass
Agent configuration.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `learning_rate` | Optimizer learning rate. |
| `entropy_scheduler` | Scheduler parameters for entropy regularization. |
| `risk_scheduler` | Scheduler parameters for the risk-sensitive objective. |
| `sil_settings` | Self-imitation learning configuration. |
RenderCfg
dataclass
Rendering configuration.
| ATTRIBUTE | DESCRIPTION |
|---|---|
n_best |
Number of top trajectories to render.
TYPE:
|
disregarded_percentage |
Percentage of trajectories to disregard based on reward (for stochastic tasks).
TYPE:
|
mode |
Rendering mode, e.g., "transients", "inputs", "final_state", etc.
TYPE:
|
Config
dataclass
Top-level configuration container.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `task` | Task configuration. |
| `solver` | Solver configuration. |
| `train` | Training configuration. |
| `library` | Library configuration. |
| `policy` | Policy configuration. |
| `agent` | Agent configuration. |
| `render` | Rendering configuration. |
to_dict()
Convert config to a JSON-serializable dictionary.
| RETURNS | DESCRIPTION |
|---|---|
| `Dict[str, Any]` | Nested dictionary of config values. |
describe(width=120)
Pretty-print the full configuration.
| PARAMETER | DESCRIPTION |
|---|---|
| `width` | Print width for formatting. |
Configurator
Helpers to create configs from presets and apply overrides.
preset(name='balanced')
staticmethod
Create a config from a named preset.
| PARAMETER | DESCRIPTION |
|---|---|
| `name` | Preset name. Supported: `"fast"` (small networks, looser tolerances), `"balanced"` (sensible defaults), `"quality"` (larger networks, more capacity), `"paper"` (settings used in the GenAI-Net paper experiments). |

| RETURNS | DESCRIPTION |
|---|---|
| `Config` | Config instance. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If the preset name is unknown. |
with_overrides(cfg, **overrides)
staticmethod
Return a deep-copied config with nested overrides applied.
| PARAMETER | DESCRIPTION |
|---|---|
| `cfg` | Base config. |
| `**overrides` | Nested dictionaries keyed by top-level sections (task, solver, train, library, policy, agent). |

| RETURNS | DESCRIPTION |
|---|---|
| `Config` | New Config with overrides applied. |
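A minimal sketch of the deep-copy-then-merge semantics, using plain dicts in place of Config sections (assumed behavior: each section dict is updated key by key, and the base config is never mutated):

```python
import copy
from typing import Any, Dict

def with_overrides(cfg: Dict[str, Any], **overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of cfg with per-section overrides applied."""
    new_cfg = copy.deepcopy(cfg)  # the base config is left untouched
    for section, values in overrides.items():
        if section not in new_cfg:
            raise KeyError(f"Unknown config section: {section!r}")
        new_cfg[section].update(values)
    return new_cfg

base = {"train": {"epochs": 100, "seed": 0}, "solver": {"rtol": 1e-6}}
cfg = with_overrides(base, train={"epochs": 500})  # base stays unchanged
```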
Session
dataclass
Container for all objects needed to run training and inspection.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `cfg` | Config used to build this session. |
| `device` | Torch device string. |
| `n_cpus` | Number of CPUs used for parallel rollouts. |
| `batch_size` | Number of parallel environments. |
| `task` | Materialized TaskSpec used to compute rewards. |
| `crn_template` | Compiled IOCRN template. |
| `species_labels` | Species labels for template/library. |
| `library` | Reaction library. |
| `M` | Number of reactions in the library. |
| `K` | Number of parameters in the library. |
| `masks` | Parameter/logit masks from the library. |
| `p` | Number of CRN input channels. |
| `mult_env` | Parallel environments. |
| `observer` | Env-to-agent observer. |
| `tensorizer` | Observer tensorizer. |
| `actuator` | Agent-to-env actuator. |
| `stepper` | Environment stepper. |
| `policy` | Policy instance. |
| `agent` | Agent instance. |
| `sample_hof` | A HallOfFame of sampled CRNs, populated after calling `sample()`. |
from_config(cfg, task, device=None, logger=None)
staticmethod
Build a Session from a Config.
| PARAMETER | DESCRIPTION |
|---|---|
| `cfg` | Configuration object. |
| `task` | Materialized TaskSpec object. |
| `device` | Torch device string. If None, auto-selects. |
| `logger` | Optional logger. |

| RETURNS | DESCRIPTION |
|---|---|
| `Session` | Initialized Session with all required RL4CRN objects wired up. |
sample(n_samples, sample_hof_size, *, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None)
Sample CRNs from the current policy without training (evaluation-only).
This method creates a temporary batch of environments, performs one rollout
(episode) per environment using the current policy in eval mode, computes
rewards, and stores the best sampled environments in a dedicated
sample_hof HallOfFame.
Sampling does not perform any learning updates (no backpropagation).
Calling this method again replaces the previously stored sample_hof, so
that different checkpoints can store different sample sets.
| PARAMETER | DESCRIPTION |
|---|---|
| `n_samples` | Number of environments to roll out (number of samples drawn). |
| `sample_hof_size` | Capacity of the sample HallOfFame (best K kept). |
| `u_list` | Optional explicit list of input vectors to evaluate. |
| `u_spec` | Optional input generation spec (same as `build_u_list`): `("custom", u_list)`, `("grid", values)`, or `("linspace", u_min, u_max, n)`. |
| `u_values` | Optional enumerated values used by `build_u_list` for grid tasks. |
| `dose_range` | Optional `(u_min, u_max, n)` for dose_response input generation. |
| `ic` | Optional IC spec override (same format accepted by `build_ic`). |
| `weights` | Optional weights spec override (same format accepted by `build_weights`). |

| RETURNS | DESCRIPTION |
|---|---|
| `HallOfFame` | The newly created sample HallOfFame containing sampled env snapshots. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If `n_samples`/`sample_hof_size` are invalid or the input dimension mismatches. |
TrainState
dataclass
Training state that persists across chunked runs.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `epoch` | Next epoch index to run. |
| `history` | List of dicts with keys `{"epoch", "best", "median"}`. |
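A sketch of how such a state could be advanced after each epoch, assuming per-environment losses where lower is better (the `record_epoch` helper is hypothetical, not part of the library):

```python
import statistics

def record_epoch(state: dict, losses: list) -> None:
    """Append a {"epoch", "best", "median"} record and advance the epoch index."""
    state["history"].append({
        "epoch": state["epoch"],
        "best": min(losses),                  # best (lowest) loss in the batch
        "median": statistics.median(losses),  # batch median for progress tracking
    })
    state["epoch"] += 1  # next epoch index to run

state = {"epoch": 0, "history": []}
record_epoch(state, [3.0, 1.0, 2.0])
```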
Trainer
Chunkable trainer with stop/resume and checkpointing.
__init__(session)
Initialize trainer.
| PARAMETER | DESCRIPTION |
|---|---|
| `session` | Built Session containing envs, agent, and task reward function. |
resimulate(crns, *, task=None, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None, n_cpus=None)
Clone and re-simulate CRNs under a task, optionally overriding conditions.
This is mainly for re-evaluating existing CRNs (e.g., from the training Hall of Fame) under new experimental conditions such as different input scenarios or initial conditions, without mutating the original CRN objects.
The method clones each CRN via .clone() before simulation, runs task reward evaluation
(which triggers transient simulations internally), and returns the cloned CRNs with
updated last_task_info.
| PARAMETER | DESCRIPTION |
|---|---|
| `crns` | List of CRN objects to re-simulate. Each must implement `.clone()`. |
| `task` | Optional TaskSpec to use. If None, defaults to the session's task. |
| `u_list` | Optional explicit list of input vectors for evaluation. |
| `u_spec` | Optional input generation spec (same as `build_u_list`). |
| `u_values` | Optional enumerated values used by `build_u_list` for grid tasks. |
| `dose_range` | Optional `(u_min, u_max, n)` for dose_response input generation. |
| `ic` | Optional IC spec override (same format accepted by `build_ic`). |
| `weights` | Optional weights spec override (same format accepted by `build_weights`). |
| `n_cpus` | Optional CPU override for evaluation. If None, uses the session's `n_cpus`. |

| RETURNS | DESCRIPTION |
|---|---|
| `List[Any]` | List of cloned CRNs after evaluation; each returned CRN has a fresh `last_task_info` corresponding to this re-simulation. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If the CRNs do not support `.clone()`. |
step_epoch()
Run a single epoch: rollout, reward eval, and policy update.
| RETURNS | DESCRIPTION |
|---|---|
| `Tuple[float, float]` | Tuple `(best_loss, median_loss)` over the batch. |
run(epochs, checkpoint_path=None)
Run training for a chunk of epochs.
| PARAMETER | DESCRIPTION |
|---|---|
| `epochs` | Number of epochs to run in this chunk. |
| `checkpoint_path` | If provided, saves a checkpoint periodically and on interrupt. |
best_crn()
Return the best CRN currently in the hall of fame.
| RETURNS | DESCRIPTION |
|---|---|
| `Optional[Any]` | Best CRN object if available, else None. |
sample(n_samples, sample_hof_size, *, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None)
Convenience wrapper around Session.sample to sample from the current policy.
inspect(crn, *, plot=True, plot_type=None, title='CRN', **kwargs)
Print and optionally plot a given CRN.
| PARAMETER | DESCRIPTION |
|---|---|
| `crn` | The CRN object to inspect. |
| `plot` | If True, call the appropriate plot method on the CRN (if available). |
| `plot_type` | Optional plot suffix (e.g., `"transient_response"`, `"logic_response"`). If None, it is inferred from the task kind. |
| `title` | Header label for the printed inspection. |
| `**kwargs` | Passed through to the selected plotting function. |

| RETURNS | DESCRIPTION |
|---|---|
| `Any` | The same CRN object (for convenience). |
inspect_best(*, plot=True, plot_type=None, **kwargs)
Inspect the current best CRN in the Hall of Fame.
inspect_hof(idx, *, plot=True, plot_type=None, sort_by_reward=True, **kwargs)
Inspect a Hall-of-Fame CRN by index.
| PARAMETER | DESCRIPTION |
|---|---|
| `idx` | Index into the HoF list. |
| `plot` | If True, plot (if possible). |
| `plot_type` | Optional plot suffix; inferred from the task kind if None. |
| `sort_by_reward` | If True, sort the HoF by reward before indexing. |

| RETURNS | DESCRIPTION |
|---|---|
| `Optional[Any]` | Selected CRN if available, else None. |
save(path)
Save a training checkpoint.
| PARAMETER | DESCRIPTION |
|---|---|
| `path` | File path to save. |
load(path, strict=True)
Load a training checkpoint.
| PARAMETER | DESCRIPTION |
|---|---|
| `path` | File path to load. |
| `strict` | Passed through to `policy.load_state_dict`. |

| RAISES | DESCRIPTION |
|---|---|
| `FileNotFoundError` | If the checkpoint file does not exist. |
loaded_hof()
Return hall-of-fame CRNs loaded from a checkpoint.
| RETURNS | DESCRIPTION |
|---|---|
| `Optional[List[Any]]` | List of CRN objects if present, else None. |
get_sampled_crns()
Return CRN states from the current sample HoF (best->worst).
TaskKindBase
Bases: ABC
Abstract base class for task-kind implementations.
Each task kind encapsulates
- validation of required parameters
- default semantics for inputs (u_list)
- construction of weights / reward function
Defaults must live here, NOT in build_u_list().
help()
staticmethod
Describe the expected params dictionary for this task kind.
| RETURNS | DESCRIPTION |
|---|---|
| `Dict[str, Any]` | Dictionary describing required/optional keys and any notes. |
pretty_help(*, width=100, bullet='-', return_str=False)
classmethod
Pretty-print the task-kind help specification in a Markdown-like list format.
This uses cls.help() (a static method implemented by each TaskKind).
The expected shape is:

    {
        "required": {<key>: <description>, ...},
        "optional": {<key>: <description>, ...},
        "notes": <string or list of strings>
    }
| PARAMETER | DESCRIPTION |
|---|---|
| `width` | Maximum line width for wrapping descriptions. |
| `bullet` | Bullet marker to use for list items (default "-"). |
| `return_str` | If True, return the formatted string instead of printing. |

| RETURNS | DESCRIPTION |
|---|---|
| `Optional[str]` | If return_str=True, returns the formatted help string. Otherwise None. |
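Given the expected help() shape above, a simplified formatter might look like this; the library's actual layout may differ:

```python
import textwrap
from typing import Any, Dict

def pretty_help(spec: Dict[str, Any], width: int = 100, bullet: str = "-") -> str:
    """Render a help() dict ({"required", "optional", "notes"}) as a bulleted list."""
    lines = []
    for section in ("required", "optional"):
        entries = spec.get(section, {})
        if entries:  # skip empty sections entirely
            lines.append(f"{section.capitalize()}:")
            for key, desc in entries.items():
                lines.append(textwrap.fill(f"{bullet} {key}: {desc}",
                                           width=width, subsequent_indent="  "))
    notes = spec.get("notes")
    if notes:
        notes = [notes] if isinstance(notes, str) else list(notes)
        lines.append("Notes:")
        lines.extend(f"{bullet} {n}" for n in notes)
    return "\n".join(lines)
```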
validate(task)
Validate that the TaskSpec contains required fields.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If required fields are missing or inconsistent. |
build_time_horizon(task)
Build or reuse the time horizon.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |

| RETURNS | DESCRIPTION |
|---|---|
| `ndarray` | Time grid array of shape `(n_t,)`, float32. |
default_u_list(task)
abstractmethod
Default semantics for generating u_list for this kind.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |

| RETURNS | DESCRIPTION |
|---|---|
| `List[ndarray]` | List of float32 input vectors, each of shape `(p,)`. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If required params are missing for default generation. |
build_u_list(task, overrides)
Build or override the u_list for evaluation.
Precedence (highest first):

1. overrides['u_list']
2. overrides['u_spec']
3. task.u_list (if the user provided an explicit list)
4. task.u_spec (special tags only)
5. TaskKind.default_u_list(task) (kind-specific semantics)

| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |
| `overrides` | Override dictionary. |

| RETURNS | DESCRIPTION |
|---|---|
| `List[ndarray]` | List of input vectors (float32 arrays), each shape `(p,)`. |
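The precedence chain reduces to "first non-None source wins". A sketch of the ordering only (spec expansion elided; a real implementation would expand u_spec tags into concrete vectors):

```python
def resolve_u_list(task: dict, overrides: dict, default_fn):
    """Return the first non-None input source, mirroring the documented order."""
    for candidate in (overrides.get("u_list"),   # 1. explicit override list
                      overrides.get("u_spec"),   # 2. override spec
                      task.get("u_list"),        # 3. user-provided task list
                      task.get("u_spec")):       # 4. task spec (special tags)
        if candidate is not None:
            return candidate
    return default_fn(task)                      # 5. kind-specific default
```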
build_ic(task, overrides)
Build the IC object from spec or override.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |
| `overrides` | Override dictionary; may contain 'ic_spec'. |

| RETURNS | DESCRIPTION |
|---|---|
| `Any` | RL4CRN IC object. |
build_weights(task, overrides)
Build weights if needed by the task kind.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |
| `overrides` | Override dictionary; may contain 'weights_spec'. |

| RETURNS | DESCRIPTION |
|---|---|
| `Optional[ndarray]` | Weight matrix or None. |
make_reward_fn(task, overrides)
abstractmethod
Construct reward function for this task kind.
get_device(prefer='auto')
Select a torch device string.
| PARAMETER | DESCRIPTION |
|---|---|
| `prefer` | Device preference: `"auto"` (choose `"cuda"` if available, else `"cpu"`), `"cpu"` (force CPU), or `"cuda"` (force CUDA; raises if not available). |

| RETURNS | DESCRIPTION |
|---|---|
| `str` | Device string ("cpu" or "cuda"). |

| RAISES | DESCRIPTION |
|---|---|
| `RuntimeError` | If `prefer="cuda"` but CUDA is not available. |
| `ValueError` | If `prefer` is not one of {"auto", "cpu", "cuda"}. |
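The documented selection logic, sketched with a `cuda_available` flag standing in for `torch.cuda.is_available()` so the example stays self-contained:

```python
def get_device(prefer: str = "auto", cuda_available: bool = False) -> str:
    """Select a device string per the documented preference rules."""
    if prefer == "auto":
        return "cuda" if cuda_available else "cpu"
    if prefer == "cpu":
        return "cpu"
    if prefer == "cuda":
        if not cuda_available:
            raise RuntimeError("CUDA requested but not available")
        return "cuda"
    raise ValueError(f"prefer must be 'auto', 'cpu', or 'cuda', got {prefer!r}")
```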
seed_everything(seed)
Seed common RNG sources for reproducibility.
| PARAMETER | DESCRIPTION |
|---|---|
| `seed` | Random seed. |
make_time_grid(t_f=100.0, n_t=1000)
Create a uniform time grid.
| PARAMETER | DESCRIPTION |
|---|---|
| `t_f` | Final time. |
| `n_t` | Number of time points. |

| RETURNS | DESCRIPTION |
|---|---|
| `ndarray` | Time grid as a float32 array of shape `(n_t,)`. |
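A direct sketch with NumPy, assuming the grid starts at t=0 and includes the endpoint t_f:

```python
import numpy as np

def make_time_grid(t_f: float = 100.0, n_t: int = 1000) -> np.ndarray:
    """Uniform float32 grid of n_t points from 0 to t_f (inclusive endpoint assumed)."""
    return np.linspace(0.0, t_f, n_t, dtype=np.float32)
```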
build_u_list(kind, *, n_inputs=None, u_values=None, dose_range=None, u_spec=None)
Construct a list of inputs for a task kind.
| PARAMETER | DESCRIPTION |
|---|---|
| `kind` | Task kind. |
| `n_inputs` | Number of input channels. |
| `u_values` | Values to enumerate for grid tasks. |
| `dose_range` | `(u_min, u_max, n)` for "dose_response" tasks. |
| `u_spec` | Optional escape hatch specifying exact input generation: `("custom", u_list)`, `("grid", values)`, or `("linspace", u_min, u_max, n)`. |

| RETURNS | DESCRIPTION |
|---|---|
| `List[ndarray]` | List of input vectors (float32 arrays). |
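The three u_spec forms can be expanded as sketched below. Assumptions (not confirmed by this page): "grid" enumerates the Cartesian product of the values over all input channels, and "linspace" applies the same scalar to every channel:

```python
import itertools
import numpy as np

def expand_u_spec(u_spec: tuple, n_inputs: int):
    """Expand a u_spec tag into float32 input vectors of shape (n_inputs,)."""
    tag = u_spec[0]
    if tag == "custom":
        return [np.asarray(u, dtype=np.float32) for u in u_spec[1]]
    if tag == "grid":  # assumed: Cartesian product of values over channels
        return [np.array(combo, dtype=np.float32)
                for combo in itertools.product(u_spec[1], repeat=n_inputs)]
    if tag == "linspace":  # assumed: same scalar applied to every channel
        u_min, u_max, n = u_spec[1:]
        return [np.full(n_inputs, u, dtype=np.float32)
                for u in np.linspace(u_min, u_max, n)]
    raise ValueError(f"Unknown u_spec tag: {tag!r}")
```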
build_ic(species_labels, ic_spec)
Build an RL4CRN IC object from a compact spec.
| PARAMETER | DESCRIPTION |
|---|---|
| `species_labels` | Species names for the CRN. |
| `ic_spec` | One of: `"zero"`, `("constant", value)`, or `("values", values_2d)`. |

| RETURNS | DESCRIPTION |
|---|---|
| `Any` | RL4CRN IC instance. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If ic_spec is unknown. |
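How the three ic_spec forms might materialize into initial concentrations, sketched with a plain array in place of the RL4CRN IC object:

```python
import numpy as np

def build_ic_values(species_labels, ic_spec):
    """Return an initial-condition array for the compact spec forms."""
    n = len(species_labels)
    if ic_spec == "zero":
        return np.zeros(n, dtype=np.float32)
    if isinstance(ic_spec, tuple) and ic_spec[0] == "constant":
        return np.full(n, ic_spec[1], dtype=np.float32)  # same value for all species
    if isinstance(ic_spec, tuple) and ic_spec[0] == "values":
        return np.asarray(ic_spec[1], dtype=np.float32)  # explicit values
    raise ValueError(f"Unknown ic_spec: {ic_spec!r}")
```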
build_weights(q, n_t, w_spec)
Build a weight matrix for tracking losses.
| PARAMETER | DESCRIPTION |
|---|---|
| `q` | Output dimension (usually 1). |
| `n_t` | Number of time points. |
| `w_spec` | One of: `"steady_state"` (weight only the last time point), `"uniform"` (all ones), `"transient"` (bias early/late times), or `("custom", array_like)`. |

| RETURNS | DESCRIPTION |
|---|---|
| `ndarray` | Weight matrix of shape `(q, n_t)`, float32. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If w_spec is unknown. |
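A sketch of the spec dispatch for the documented forms; `"transient"` is omitted here because its exact early/late weighting profile is not specified on this page:

```python
import numpy as np

def build_weight_matrix(q: int, n_t: int, w_spec):
    """Build a (q, n_t) float32 weight matrix from a compact spec (sketch)."""
    if w_spec == "uniform":
        return np.ones((q, n_t), dtype=np.float32)
    if w_spec == "steady_state":
        w = np.zeros((q, n_t), dtype=np.float32)
        w[:, -1] = 1.0  # weight only the last time point
        return w
    if isinstance(w_spec, tuple) and w_spec[0] == "custom":
        return np.asarray(w_spec[1], dtype=np.float32).reshape(q, n_t)
    raise ValueError(f"Unknown w_spec: {w_spec!r}")
```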
make_task(template_crn, library_components, kind, species_labels, *, params=None)
Create a TaskSpec from a params dictionary and build its reward callable.
This is the ONLY public constructor: users pass task knobs via params.
Default interpretation of missing fields is delegated to the TaskKind handler.
Common params keys (shared across many tasks):

- "t_f": float
- "n_t": int
- "n_inputs": int (defaults to template_crn.num_inputs)
- "ic": Union[str, tuple]  # e.g. "zero", ("constant", 0.01)
- "weights": Union[str, tuple]  # e.g. "transient", ("custom", ...)
- "u_spec": tuple  # only for special-tag generation
- "u_list": List[np.ndarray]  # explicit scenarios
Task-specific keys are documented by TaskKind.help().
| PARAMETER | DESCRIPTION |
|---|---|
| `template_crn` | Compiled IOCRN template. |
| `library_components` | Tuple `(library, M, K, masks)`. |
| `kind` | Task kind string. |
| `species_labels` | Species labels used by the task. |
| `params` | Task configuration dictionary. |

| RETURNS | DESCRIPTION |
|---|---|
| `TaskSpec` | TaskSpec with runtime fields (time_horizon/u_list/ic/weights/compute_reward) populated. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If required parameters are missing or inconsistent. |
make_reward_fn_with_overrides(task, *, u_list=None, ic_spec=None, weights_spec=None, **kwargs)
Build a reward function for a TaskSpec, optionally overriding conditions.
This is the single entry point used by training, sampling, resimulation, and load.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | Base TaskSpec. |
| `u_list` | Optional replacement list of input vectors. |
| `ic_spec` | Optional IC spec override. |
| `weights_spec` | Optional weights spec override. |
| `**kwargs` | Additional task-kind-specific overrides. |

| RETURNS | DESCRIPTION |
|---|---|
| `Callable[[Any], Union[float, Tuple[float, Dict[str, Any]]]]` | Reward callable accepting a CRN state and returning loss or (loss, info). |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If task.kind is unknown or required fields are missing. |
build_envs(template, max_added_reactions, batch_size, hall_of_fame_size, n_cpus, logger=None)
Create parallel environments.
| PARAMETER | DESCRIPTION |
|---|---|
| `template` | IOCRN template. |
| `max_added_reactions` | Episode length. |
| `batch_size` | Number of environments. |
| `hall_of_fame_size` | Hall-of-fame capacity. |
| `n_cpus` | Number of CPUs for parallel execution. |
| `logger` | Optional logger. |

| RETURNS | DESCRIPTION |
|---|---|
| | ParallelEnvironments instance. |
build_interfaces(library, device, allow_input_influence=False)
Build standard env<->agent interfaces.
| PARAMETER | DESCRIPTION |
|---|---|
| `library` | Reaction library. |
| `device` | Torch device string. |
| `allow_input_influence` | Whether to allow input-influence features. |

| RETURNS | DESCRIPTION |
|---|---|
| | Tuple `(observer, tensorizer, actuator, stepper)`. |
build_policy(M, K, p, masks, device, policy_cfg, target_set_size)
Build the policy instance.
| PARAMETER | DESCRIPTION |
|---|---|
| `M` | Number of reactions in the library. |
| `K` | Number of total library parameters. |
| `p` | Number of input channels in the CRN. |
| `masks` | Parameter/logit masks from the library. |
| `device` | Torch device string. |
| `policy_cfg` | PolicyCfg instance. |
| `target_set_size` | Required for the ordered policy. |

| RETURNS | DESCRIPTION |
|---|---|
| | Policy instance. |
build_agent(policy, device, agent_cfg, logger=None)
Build the REINFORCE(+SIL) agent.
| PARAMETER | DESCRIPTION |
|---|---|
| `policy` | Policy instance. |
| `device` | Torch device string. |
| `agent_cfg` | AgentCfg instance. |
| `logger` | Optional logger. |

| RETURNS | DESCRIPTION |
|---|---|
| | REINFORCEAgent instance. |
make_session_and_trainer(cfg, task, device='auto', logger=None)
print_task_summary(task, max_preview=3)
Compact TaskSpec summary.
run_smoke_reward(task, state, label='')
Call task.compute_reward on a given state and print normalized output.
load_session_and_trainer(checkpoint_path, *, task=None, device='auto', strict=True)
Load a checkpoint and reconstruct a working Trainer.
This convenience function rebuilds the Session/Trainer wiring from scratch and then applies checkpoint state (policy weights, training state, HoFs, RNG state). It also rebuilds runtime-only callables (e.g., task.compute_reward).
Notes:

- If `task` is provided, it is used as the task definition and the checkpoint policy weights/state are loaded onto it.
- If `task` is None, this function expects the checkpoint's `config` to contain a serializable TaskSpec under `config['task']`.
| PARAMETER | DESCRIPTION |
|---|---|
| `checkpoint_path` | Path to the checkpoint file created by `Trainer.save()`. |
| `task` | Optional TaskSpec to use instead of the checkpoint's saved task. |
| `device` | Device preference ("auto", "cpu", "cuda"). |
| `strict` | Whether to strictly enforce key matching in `policy.load_state_dict`. |

| RETURNS | DESCRIPTION |
|---|---|
| | Trainer object fully reconstructed and ready to use. |

| RAISES | DESCRIPTION |
|---|---|
| `FileNotFoundError` | If checkpoint_path does not exist. |
| `KeyError` | If required keys are missing from the checkpoint. |
| `ValueError` | If task reconstruction fails. |
register_task_kind(cls)
Register a TaskKindBase subclass into the global registry.
| PARAMETER | DESCRIPTION |
|---|---|
| `cls` | TaskKind class. |

| RETURNS | DESCRIPTION |
|---|---|
| `type[TaskKindBase]` | The same class, for decorator usage. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If the class does not define 'kind', or the kind is a duplicate. |
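The decorator/registry pattern described above can be sketched as follows (`LogicKind` is a hypothetical stand-in for a TaskKindBase subclass):

```python
_TASK_KINDS: dict = {}

def register_task_kind(cls):
    """Register a task-kind class under its 'kind' attribute (decorator-friendly)."""
    kind = getattr(cls, "kind", None)
    if not kind:
        raise ValueError("Task-kind class must define a 'kind' attribute")
    if kind in _TASK_KINDS:
        raise ValueError(f"Duplicate task kind: {kind!r}")
    _TASK_KINDS[kind] = cls
    return cls  # return the class unchanged so it works as a decorator

def get_task_kind(kind: str):
    """Instantiate a registered handler by name."""
    if kind not in _TASK_KINDS:
        raise ValueError(f"Unknown task kind: {kind!r}")
    return _TASK_KINDS[kind]()

@register_task_kind
class LogicKind:  # hypothetical stand-in for a TaskKindBase subclass
    kind = "logic"
```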
get_task_kind(kind)
Instantiate a task-kind handler by name.
| PARAMETER | DESCRIPTION |
|---|---|
| `kind` | Task kind string. |

| RETURNS | DESCRIPTION |
|---|---|
| `TaskKindBase` | Instance of a TaskKindBase subclass. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If kind is unknown. |
overrides_get(task, overrides, key, *, fallback_attr=None, default=None)
Resolve a parameter using precedence overrides > task.params > task.
| PARAMETER | DESCRIPTION |
|---|---|
| `task` | TaskSpec instance. |
| `overrides` | Override dictionary. |
| `key` | Key to search in overrides/task.params. |
| `fallback_attr` | If provided, also search the task attribute with this name. |
| `default` | Default if not found. |

| RETURNS | DESCRIPTION |
|---|---|
| `Any` | Resolved value or default. |