
input_interface

RL4CRN.utils.input_interface

User-facing input interface utilities for GenAI-Net / RL4CRN tutorials.

This module provides:
  • lightweight configuration objects with sensible defaults
  • a configurator to apply presets and overrides
  • a session builder that wires together the task, template, library, environments, interfaces, policy, and agent
  • a trainer that supports chunked training, manual early stopping (Ctrl+C), and save/load checkpoints

The goal is to make tutorial notebooks trivial to run, while keeping all advanced knobs discoverable via config inspection.

TaskSpec dataclass

Fully materialized task description used by environments.

ATTRIBUTE DESCRIPTION
template_crn

Compiled IOCRN template.

TYPE: IOCRN

library_components

Tuple (library, M, K, masks).

TYPE: tuple[ReactionLibrary, int, int, dict[str, Any]]

species_labels

Species labels used by the task.

TYPE: List[str]

kind

Task kind (e.g., "logic", "tracking", "oscillator_mean", ...).

TYPE: str

t_f

Final simulation time.

TYPE: float

n_t

Number of time points.

TYPE: int

time_horizon

1D array of time points (float32).

TYPE: ndarray

n_inputs

Number of input channels.

TYPE: Optional[int]

u_values

Values for grid tasks (tracking/oscillator/SSA).

TYPE: Optional[List[float]]

dose_range

(u_min, u_max, n) for "dose_response".

TYPE: Optional[Tuple[float, float, int]]

u_spec

Optional input generation spec.

TYPE: Optional[tuple]

u_list

List of input vectors (each shape (p,), float32).

TYPE: List[ndarray]

ic_spec

IC specification used to build the IC object.

TYPE: Union[str, tuple]

ic

RL4CRN IC object.

TYPE: Any

weights_spec

Weight spec used to build the weight matrix (when applicable).

TYPE: Union[str, tuple]

weights

Weight matrix (when applicable).

TYPE: Optional[ndarray]

target

Target spec for tracking/SSA tasks.

TYPE: Optional[Union[str, float]]

logic_fn

Boolean logic function for "logic".

TYPE: Optional[VectorLogic]

target_fn

Target function for dose response.

TYPE: Optional[Callable[[float], float]]

osc_w

Oscillation error weights.

TYPE: Optional[List[float]]

t0

Oscillation error start time.

TYPE: float

n_trajectories

SSA number of trajectories.

TYPE: int

max_threads

SSA max threads.

TYPE: int

cv_weight

Robust SSA CV weight.

TYPE: float

rpa_weight

Robust SSA RPA weight.

TYPE: float

relative

Whether to use relative error in SSA rewards.

TYPE: bool

norm

Norm used in tracking losses.

TYPE: int

LARGE_NUMBER

Large penalty scalar used by deterministic rewards.

TYPE: float

LARGE_PENALTY

Large penalty scalar used by SSA rewards (when applicable).

TYPE: float

compute_reward

Reward callable built from this TaskSpec.

TYPE: Optional[Callable[[Any], Union[float, Tuple[float, Dict[str, Any]]]]]

params

Task-kind-specific parameters (forward-compatible extension point).

TYPE: Dict[str, Any]

SolverCfg dataclass

Solver configuration.

ATTRIBUTE DESCRIPTION
algorithm

Solver name (e.g., "CVODE" or "LSODA").

TYPE: str

rtol

Relative tolerance.

TYPE: float

atol

Absolute tolerance.

TYPE: float

TrainCfg dataclass

Training configuration.

ATTRIBUTE DESCRIPTION
epochs

Total number of epochs (you may run in chunks).

TYPE: int

max_added_reactions

Episode length: number of reaction-addition steps.

TYPE: int

render_every

Print progress every N epochs (0 disables).

TYPE: int

hall_of_fame_size

Hall-of-fame capacity in ParallelEnvironments.

TYPE: int

batch_multiplier

Batch size = batch_multiplier * n_cpus when batch_size is None.

TYPE: int

seed

Random seed for reproducibility.

TYPE: int

n_cpus

CPU count to use. If None, uses os.cpu_count().

TYPE: Optional[int]

batch_size

If provided, overrides auto batch sizing.

TYPE: Optional[int]
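The interplay of batch_multiplier, n_cpus, and batch_size described above can be sketched as a small resolver. This is a hypothetical helper (`resolve_batch_size` is not part of RL4CRN); the fallback to 1 when os.cpu_count() returns None is an assumption of the sketch.

```python
import os
from typing import Optional


def resolve_batch_size(batch_multiplier: int,
                       n_cpus: Optional[int] = None,
                       batch_size: Optional[int] = None) -> int:
    """Hypothetical resolver mirroring the TrainCfg descriptions."""
    if batch_size is not None:
        return batch_size  # an explicit batch_size overrides auto sizing
    # n_cpus=None means "use os.cpu_count()"; cpu_count() itself may return None
    cpus = n_cpus if n_cpus is not None else (os.cpu_count() or 1)
    return batch_multiplier * cpus


print(resolve_batch_size(4, n_cpus=8))                 # 32
print(resolve_batch_size(4, n_cpus=8, batch_size=10))  # 10
```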

PolicyCfg dataclass

Policy network configuration.

ATTRIBUTE DESCRIPTION
width

Hidden size for encoder/heads.

TYPE: int

depth

Number of layers for encoder/heads.

TYPE: int

deep_layer_size

Size of deep layer block (policy-dependent).

TYPE: int

continuous_distribution

Dict describing continuous parameter distribution.

TYPE: Dict[str, Any]

entropy_weights_per_head

Entropy coefficients per head.

TYPE: Dict[str, float]

ordering_enabled

If True, uses ordered reaction addition policy.

TYPE: bool

constraint_strength

Constraint strength for ordered policy.

TYPE: float

zero_reaction_idx

If set, this reaction index is treated as a "no-op" action (allowing multiple re-samples).

TYPE: Optional[int]

stop_flag

Internal flag indicating whether selecting the "zero_reaction" action is a stopping condition (instead of a no-op).

TYPE: bool

AgentCfg dataclass

Agent configuration.

ATTRIBUTE DESCRIPTION
learning_rate

Optimizer learning rate.

TYPE: float

entropy_scheduler

Scheduler parameters for entropy regularization.

TYPE: Dict[str, Any]

risk_scheduler

Scheduler parameters for risk-sensitive objective.

TYPE: Dict[str, Any]

sil_settings

Self-imitation learning configuration.

TYPE: Dict[str, Any]

RenderCfg dataclass

Rendering configuration.

ATTRIBUTE DESCRIPTION
n_best

Number of top trajectories to render.

TYPE: int

disregarded_percentage

Percentage of trajectories to disregard based on reward (for stochastic tasks).

TYPE: float

mode

Rendering mode, e.g., "transients", "inputs", "final_state", etc.

TYPE: str

Config dataclass

Top-level configuration container.

ATTRIBUTE DESCRIPTION
task

Task configuration.

TYPE: TaskSpec

solver

Solver configuration.

TYPE: SolverCfg

train

Training configuration.

TYPE: TrainCfg

library

Library configuration.

TYPE: TrainCfg

policy

Policy configuration.

TYPE: PolicyCfg

agent

Agent configuration.

TYPE: AgentCfg

render

Rendering configuration.

TYPE: RenderCfg

to_dict()

Convert config to a JSON-serializable dictionary.

RETURNS DESCRIPTION
Dict[str, Any]

Nested dictionary of config values.

describe(width=120)

Pretty-print the full configuration.

PARAMETER DESCRIPTION
width

Print width for formatting.

TYPE: int DEFAULT: 120

Configurator

Helpers to create configs from presets and apply overrides.

preset(name='balanced') staticmethod

Create a config from a named preset.

PARAMETER DESCRIPTION
name

Preset name. Supported:
  • "fast": small networks, looser tolerances
  • "balanced": sensible defaults
  • "quality": larger networks, more capacity
  • "paper": settings used in the GenAI-Net paper experiments

TYPE: str DEFAULT: 'balanced'

RETURNS DESCRIPTION
Config

Config instance.

RAISES DESCRIPTION
ValueError

If preset name is unknown.

with_overrides(cfg, **overrides) staticmethod

Return a deep-copied config with nested overrides applied.

PARAMETER DESCRIPTION
cfg

Base config.

TYPE: Config

**overrides

Nested dictionaries keyed by top-level sections (task, solver, train, library, policy, agent).

TYPE: Dict[str, Any] DEFAULT: {}

RETURNS DESCRIPTION
Config

New Config with overrides applied.
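The deep-copy-then-patch behaviour of with_overrides (called as, e.g., `Configurator.with_overrides(cfg, train={"epochs": 500})`) can be illustrated with two stand-in sections. The dataclasses and their default values below are hypothetical and are not the library's preset values; this sketch also rejects unknown keys, which the docs above do not specify either way.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict


# Hypothetical stand-ins for two config sections (illustrative defaults only).
@dataclass
class SolverSection:
    algorithm: str = "CVODE"
    rtol: float = 1e-6


@dataclass
class TrainSection:
    epochs: int = 100
    seed: int = 0


@dataclass
class MiniConfig:
    solver: SolverSection = field(default_factory=SolverSection)
    train: TrainSection = field(default_factory=TrainSection)


def with_overrides(cfg: MiniConfig, **overrides: Dict[str, Any]) -> MiniConfig:
    """Deep-copy cfg, then set section.key = value for every nested override."""
    new_cfg = copy.deepcopy(cfg)  # the base config is never mutated
    for section_name, updates in overrides.items():
        section = getattr(new_cfg, section_name)  # AttributeError on unknown section
        for key, value in updates.items():
            if not hasattr(section, key):
                raise KeyError(f"{section_name}.{key} is not a known field")
            setattr(section, key, value)
    return new_cfg


base = MiniConfig()
tuned = with_overrides(base, train={"epochs": 500}, solver={"rtol": 1e-4})
print(tuned.train.epochs, base.train.epochs)  # 500 100
```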

Session dataclass

Container for all objects needed to run training and inspection.

ATTRIBUTE DESCRIPTION
cfg

Config used to build this session.

TYPE: Config

device

Torch device string.

TYPE: str

n_cpus

Number of CPUs used for parallel rollouts.

TYPE: int

batch_size

Number of parallel environments.

TYPE: int

task

Materialized TaskSpec used to compute rewards.

TYPE: TaskSpec

crn_template

Compiled IOCRN template.

TYPE: Any

species_labels

Species labels for template/library.

TYPE: List[str]

library

Reaction library.

TYPE: Any

M

Number of reactions in library.

TYPE: int

K

Number of parameters in library.

TYPE: int

masks

Parameter/logit masks from the library.

TYPE: Dict[str, Any]

p

Number of CRN input channels.

TYPE: int

mult_env

Parallel environments.

TYPE: Any

observer

Env->agent observer.

TYPE: Any

tensorizer

Observer tensorizer.

TYPE: Any

actuator

Agent->env actuator.

TYPE: Any

stepper

Environment stepper.

TYPE: Any

policy

Policy instance.

TYPE: Any

agent

Agent instance.

TYPE: Any

sample_hof

A HallOfFame holding the best CRNs drawn by the most recent call to sample(); populated only after sample() has been called.

TYPE: Optional[HallOfFame]

from_config(cfg, task, device=None, logger=None) staticmethod

Build a Session from a Config.

PARAMETER DESCRIPTION
cfg

Configuration object.

TYPE: Config

task

Materialized TaskSpec object.

TYPE: TaskSpec

device

Torch device string. If None, auto-selects.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
Session

Initialized Session with all required RL4CRN objects wired up.

sample(n_samples, sample_hof_size, *, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None)

Sample CRNs from the current policy without training (evaluation-only).

This method creates a temporary batch of environments, performs one rollout (episode) per environment using the current policy in eval mode, computes rewards, and stores the best sampled environments in a dedicated sample_hof HallOfFame.

Sampling does not perform any learning updates (no backpropagation).

Calling this method again replaces the previously stored sample_hof, so that different checkpoints can store different sample sets.

PARAMETER DESCRIPTION
n_samples

Number of environments to roll out (number of samples drawn).

TYPE: int

sample_hof_size

Capacity of the sample HallOfFame; only the best sample_hof_size samples are kept.

TYPE: int

u_list

Optional explicit list of input vectors to evaluate.

TYPE: Optional[List[ndarray]] DEFAULT: None

u_spec

Optional input generation spec, in the same format as build_u_list: ("custom", u_list), ("grid", values), or ("linspace", u_min, u_max, n).

TYPE: Optional[tuple] DEFAULT: None

u_values

Optional enumerated values used by build_u_list for grid tasks.

TYPE: Optional[List[float]] DEFAULT: None

dose_range

Optional (u_min, u_max, n) for dose_response input generation.

TYPE: Optional[Tuple[float, float, int]] DEFAULT: None

ic

Optional IC spec override (same format accepted by build_ic).

TYPE: Optional[Union[str, tuple]] DEFAULT: None

weights

Optional weights spec override (same format accepted by build_weights).

TYPE: Optional[Union[str, tuple]] DEFAULT: None

RETURNS DESCRIPTION
HallOfFame

The newly created sample HallOfFame containing sampled env snapshots.

TYPE: HallOfFame

RAISES DESCRIPTION
ValueError

If n_samples or sample_hof_size is invalid, or if the input dimensions do not match.
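Keeping only the best sample_hof_size samples amounts to a top-K selection. The sketch below is hypothetical (`keep_best` is not an RL4CRN function) and assumes, as the ascending sort in inspect_hof suggests, that a lower reward value (a loss) is better. Each call builds a fresh list, matching the note that a new sample() call replaces the previous sample_hof.

```python
import heapq
from typing import Any, List, Tuple


def keep_best(samples: List[Tuple[float, Any]],
              sample_hof_size: int) -> List[Tuple[float, Any]]:
    """Return the sample_hof_size (loss, crn) pairs with the lowest loss, best first."""
    if sample_hof_size <= 0:
        raise ValueError("sample_hof_size must be positive")
    return heapq.nsmallest(sample_hof_size, samples, key=lambda item: item[0])


drawn = [(0.9, "crn_a"), (0.1, "crn_b"), (0.5, "crn_c"), (0.3, "crn_d")]
print(keep_best(drawn, 2))  # [(0.1, 'crn_b'), (0.3, 'crn_d')]
```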

TrainState dataclass

Training state that persists across chunked runs.

ATTRIBUTE DESCRIPTION
epoch

Next epoch index to run.

TYPE: int

history

List of dicts with keys {"epoch","best","median"}.

TYPE: List[Dict[str, float]]

Trainer

Chunkable trainer with stop/resume and checkpointing.

__init__(session)

Initialize trainer.

PARAMETER DESCRIPTION
session

Built Session containing envs, agent, and task reward function.

TYPE: Session

resimulate(crns, *, task=None, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None, n_cpus=None)

Clone and re-simulate CRNs under a task, optionally overriding conditions.

This is mainly for re-evaluating existing CRNs (e.g., from the training Hall of Fame) under new experimental conditions such as different input scenarios or initial conditions, without mutating the original CRN objects.

The method clones each CRN via .clone() before simulation, runs task reward evaluation (which triggers transient simulations internally), and returns the cloned CRNs with updated last_task_info.

PARAMETER DESCRIPTION
crns

List of CRN objects to re-simulate. Each must implement .clone().

TYPE: List[Any]

task

Optional TaskSpec to use. If None, defaults to self.s.task.

TYPE: Optional[TaskSpec] DEFAULT: None

u_list

Optional explicit list of input vectors for evaluation.

TYPE: Optional[List[ndarray]] DEFAULT: None

u_spec

Optional input generation spec (same as build_u_list), used if u_list is None.

TYPE: Optional[tuple] DEFAULT: None

u_values

Optional enumerated values used by build_u_list for grid tasks.

TYPE: Optional[List[float]] DEFAULT: None

dose_range

Optional (u_min, u_max, n) for dose_response input generation.

TYPE: Optional[Tuple[float, float, int]] DEFAULT: None

ic

Optional IC spec override (same format accepted by build_ic).

TYPE: Optional[Union[str, tuple]] DEFAULT: None

weights

Optional weights spec override (same format accepted by build_weights).

TYPE: Optional[Union[str, tuple]] DEFAULT: None

n_cpus

Optional CPU override for evaluation. If None, uses self.s.n_cpus.

TYPE: Optional[int] DEFAULT: None

RETURNS DESCRIPTION
List[Any]

List of cloned CRNs after evaluation. Each returned CRN has a fresh last_task_info corresponding to this re-simulation.

RAISES DESCRIPTION
ValueError

If CRNs do not support .clone() or inputs have inconsistent dimensions.

step_epoch()

Run a single epoch: rollout, reward eval, and policy update.

RETURNS DESCRIPTION
Tuple[float, float]

Tuple (best_loss, median_loss) over the batch.
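How the (best_loss, median_loss) pair relates to the TrainState.history entries can be sketched as follows. `summarize_epoch` is a hypothetical helper, and "best" is assumed to mean the lowest loss in the batch.

```python
import statistics
from typing import Dict, List, Tuple


def summarize_epoch(losses: List[float]) -> Tuple[float, float]:
    """Best (lowest) and median loss over one batch of episodes."""
    return min(losses), statistics.median(losses)


history: List[Dict[str, float]] = []
best, median = summarize_epoch([0.8, 0.2, 0.5, 0.4])
history.append({"epoch": 0, "best": best, "median": median})
print(history)  # [{'epoch': 0, 'best': 0.2, 'median': 0.45}]
```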

run(epochs, checkpoint_path=None)

Run training for a chunk of epochs.

PARAMETER DESCRIPTION
epochs

Number of epochs to run in this chunk.

TYPE: int

checkpoint_path

If provided, saves a checkpoint periodically and on interrupt.

TYPE: Optional[str] DEFAULT: None
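The chunk-and-resume pattern (periodic checkpoints, a checkpoint on Ctrl+C, and a returned "next epoch" index like TrainState.epoch) can be sketched independently of RL4CRN. `run_chunk`, `step_epoch`, `save`, and `save_every` are hypothetical names; the real Trainer.run may choose its save schedule differently.

```python
from typing import Callable, Optional


def run_chunk(step_epoch: Callable[[int], None],
              start_epoch: int,
              epochs: int,
              save: Optional[Callable[[int], None]] = None,
              save_every: int = 10) -> int:
    """Run up to `epochs` epochs from start_epoch and return the next epoch index."""
    epoch = start_epoch
    try:
        for _ in range(epochs):
            step_epoch(epoch)
            epoch += 1
            if save is not None and epoch % save_every == 0:
                save(epoch)  # periodic checkpoint
    except KeyboardInterrupt:
        # Ctrl+C stops the chunk early; checkpoint so training can resume here
        print(f"Interrupted; next epoch is {epoch}")
        if save is not None:
            save(epoch)
    return epoch


saved = []


def fake_step(epoch: int) -> None:
    if epoch == 3:
        raise KeyboardInterrupt  # simulate the user pressing Ctrl+C


next_epoch = run_chunk(fake_step, 0, 10, save=saved.append, save_every=2)
print(next_epoch, saved)  # 3 [2, 3]
```

Resuming is then just another call that starts at `next_epoch`.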

best_crn()

Return the best CRN currently in the hall of fame.

RETURNS DESCRIPTION
Optional[Any]

Best CRN object if available, else None.

sample(n_samples, sample_hof_size, *, u_list=None, u_spec=None, u_values=None, dose_range=None, ic=None, weights=None)

Convenience wrapper around Session.sample to sample from the current policy.

inspect(crn, *, plot=True, plot_type=None, title='CRN', **kwargs)

Print and optionally plot a given CRN.

PARAMETER DESCRIPTION
crn

The CRN object to inspect.

TYPE: Any

plot

If True, call the appropriate plot method on the CRN (if available).

TYPE: bool DEFAULT: True

plot_type

Optional plot suffix (e.g., "transient_response", "logic_response"). If None, it is inferred from self.s.cfg.task.kind.

TYPE: Optional[str] DEFAULT: None

title

Header label for the printed inspection.

TYPE: str DEFAULT: 'CRN'

**kwargs

Passed through to the selected plotting function.

DEFAULT: {}

RETURNS DESCRIPTION
Any

The same CRN object (for convenience).

inspect_best(*, plot=True, plot_type=None, **kwargs)

Inspect the current best CRN in the Hall of Fame.

inspect_hof(idx, *, plot=True, plot_type=None, sort_by_reward=True, **kwargs)

Inspect a Hall-of-Fame CRN by index.

PARAMETER DESCRIPTION
idx

Index into the HoF list. If sort_by_reward=True, index is taken after sorting by ascending reward.

TYPE: int

plot

If True, plot (if possible).

TYPE: bool DEFAULT: True

plot_type

Optional plot suffix; inferred from task kind if None.

TYPE: Optional[str] DEFAULT: None

sort_by_reward

If True, sort HoF by last_task_info['reward'] ascending.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Optional[Any]

Selected CRN if available, else None.

save(path)

Save a training checkpoint.

PARAMETER DESCRIPTION
path

File path to save.

TYPE: str

load(path, strict=True)

Load a training checkpoint.

PARAMETER DESCRIPTION
path

File path to load.

TYPE: str

strict

Passed through to policy.load_state_dict.

TYPE: bool DEFAULT: True

RAISES DESCRIPTION
FileNotFoundError

If checkpoint file does not exist.

loaded_hof()

Return hall-of-fame CRNs loaded from a checkpoint.

RETURNS DESCRIPTION
Optional[List[Any]]

List of CRN objects if present, else None.

get_sampled_crns()

Return CRN states from the current sample HoF (best->worst).

TaskKindBase

Bases: ABC

Abstract base class for task-kind implementations.

Each task kind encapsulates
  • validation of required parameters
  • default semantics for inputs (u_list)
  • construction of weights / reward function

Defaults must live here, NOT in build_u_list().

help() staticmethod

Describe the expected params dictionary for this task kind.

RETURNS DESCRIPTION
Dict[str, Any]

Dictionary describing required/optional keys and any notes.

pretty_help(*, width=100, bullet='-', return_str=False) classmethod

Pretty-print the task-kind help specification in a Markdown-like list format.

This uses cls.help() (a static method implemented by each TaskKind). The expected shape is:

{
  "required": {<key>: <description>, ...},
  "optional": {<key>: <description>, ...},
  "notes": <string or list of strings>
}
PARAMETER DESCRIPTION
width

Maximum line width for wrapping descriptions.

TYPE: int DEFAULT: 100

bullet

Bullet marker to use for list items (default "-").

TYPE: str DEFAULT: '-'

return_str

If True, return the formatted string instead of printing.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Optional[str]

If return_str=True, returns the formatted help string. Otherwise None.
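A formatter for the {"required", "optional", "notes"} shape described above might look like the following. `format_help` is a hypothetical stand-in for pretty_help, and the exact layout it produces (section headers, hanging indent) is an assumption, not the library's output.

```python
import textwrap
from typing import Any, Dict, List


def format_help(spec: Dict[str, Any], width: int = 100, bullet: str = "-") -> str:
    """Render a help dict as a Markdown-like list, wrapping at `width` columns."""
    lines: List[str] = []
    for section in ("required", "optional"):
        entries = spec.get(section) or {}
        if not entries:
            continue  # skip empty sections
        lines.append(f"{section.capitalize()}:")
        for key, desc in entries.items():
            lines.append(textwrap.fill(f"{bullet} {key}: {desc}", width=width,
                                       subsequent_indent="  "))
    notes = spec.get("notes")
    if notes:
        lines.append("Notes:")
        # "notes" may be a single string or a list of strings
        for note in ([notes] if isinstance(notes, str) else notes):
            lines.append(textwrap.fill(f"{bullet} {note}", width=width,
                                       subsequent_indent="  "))
    return "\n".join(lines)


example = {"required": {"t_f": "Final simulation time."}, "optional": {},
           "notes": "Illustrative only."}
print(format_help(example))
```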

validate(task)

Validate that the TaskSpec contains required fields.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

RAISES DESCRIPTION
ValueError

If required fields are missing or inconsistent.

build_time_horizon(task)

Build or reuse the time horizon.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

RETURNS DESCRIPTION
ndarray

Time grid array of shape (n_t,) float32.

default_u_list(task) abstractmethod

Default semantics for generating u_list for this kind.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

RETURNS DESCRIPTION
List[ndarray]

List of float32 input vectors, each shape (p,).

RAISES DESCRIPTION
ValueError

If required params are missing for default generation.

build_u_list(task, overrides)

Build or override the u_list for evaluation.

Precedence

From highest to lowest:
  1. overrides['u_list']
  2. overrides['u_spec']
  3. task.u_list (if the user provided an explicit list)
  4. task.u_spec (special tags only)
  5. TaskKind.default_u_list(task) (kind-specific semantics)

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

overrides

Override dictionary.

TYPE: Dict[str, Any]

RETURNS DESCRIPTION
List[ndarray]

List of input vectors (float32 arrays), each shape (p,).
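The precedence order above can be read as a chain of fallbacks. In this hypothetical sketch, `resolve_u_list`, `expand_spec`, and `default_u_list` are illustrative names: the spec expander and the kind-specific default are passed in as callables so the chain itself stays visible.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence

import numpy as np


def resolve_u_list(
    overrides: Dict[str, Any],
    task_u_list: Optional[List[Sequence[float]]],
    task_u_spec: Optional[tuple],
    expand_spec: Callable[[tuple], List[Sequence[float]]],
    default_u_list: Callable[[], List[Sequence[float]]],
) -> List[np.ndarray]:
    """Take the first available input source, from highest to lowest precedence."""
    if overrides.get("u_list") is not None:
        chosen = overrides["u_list"]
    elif overrides.get("u_spec") is not None:
        chosen = expand_spec(overrides["u_spec"])
    elif task_u_list:  # only an explicit, non-empty user list counts
        chosen = task_u_list
    elif task_u_spec is not None:
        chosen = expand_spec(task_u_spec)
    else:
        chosen = default_u_list()  # kind-specific fallback
    return [np.asarray(u, dtype=np.float32) for u in chosen]


def expand_custom(spec: tuple) -> List[Sequence[float]]:
    return spec[1]  # hypothetical: handles ("custom", rows) only


def zero_default() -> List[Sequence[float]]:
    return [[0.0]]  # hypothetical default: one all-zero input


print(resolve_u_list({"u_spec": ("custom", [[1.0], [2.0]])}, [[5.0]], None,
                     expand_custom, zero_default))
```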

build_ic(task, overrides)

Build the IC object from spec or override.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

overrides

Override dictionary, may contain 'ic_spec'.

TYPE: Dict[str, Any]

RETURNS DESCRIPTION
Any

RL4CRN IC object.

build_weights(task, overrides)

Build weights if needed by the task kind.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

overrides

Override dictionary, may contain 'weights_spec'.

TYPE: Dict[str, Any]

RETURNS DESCRIPTION
Optional[ndarray]

Weight matrix or None.

make_reward_fn(task, overrides) abstractmethod

Construct the reward function for this task kind.

get_device(prefer='auto')

Select a torch device string.

PARAMETER DESCRIPTION
prefer

Device preference. Options:
  • "auto": choose "cuda" if available, else "cpu"
  • "cpu": force CPU
  • "cuda": force CUDA (raises if not available)

TYPE: str DEFAULT: 'auto'

RETURNS DESCRIPTION
str

Device string ("cpu" or "cuda").

RAISES DESCRIPTION
RuntimeError

If prefer="cuda" but CUDA is not available.

ValueError

If prefer is not one of {"auto", "cpu", "cuda"}.
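The selection rules and error cases above fit in a few lines. In this hypothetical sketch (`select_device` is not the library function), CUDA availability is passed in as a flag so it runs without PyTorch; with PyTorch installed, that flag would typically come from `torch.cuda.is_available()`.

```python
def select_device(prefer: str = "auto", cuda_available: bool = False) -> str:
    """Map a device preference to "cpu" or "cuda" following the documented rules."""
    if prefer not in {"auto", "cpu", "cuda"}:
        raise ValueError(f"prefer must be 'auto', 'cpu', or 'cuda', got {prefer!r}")
    if prefer == "cpu":
        return "cpu"
    if prefer == "cuda":
        if not cuda_available:
            raise RuntimeError("prefer='cuda' but CUDA is not available")
        return "cuda"
    return "cuda" if cuda_available else "cpu"  # "auto"


print(select_device("auto", cuda_available=False))  # cpu
```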

seed_everything(seed)

Seed common RNG sources for reproducibility.

PARAMETER DESCRIPTION
seed

Random seed.

TYPE: int
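The docs do not list which RNG sources seed_everything covers. A common pattern, shown here as a hypothetical sketch, seeds Python's `random`, NumPy's global RNG, and PyTorch when it is importable.

```python
import random

import numpy as np


def seed_everything_sketch(seed: int) -> None:
    """Seed Python, NumPy, and (if installed) PyTorch global RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed; nothing more to seed
    torch.manual_seed(seed)  # seeds the CPU RNG and all CUDA devices


seed_everything_sketch(0)
first = (random.random(), np.random.rand())
seed_everything_sketch(0)
second = (random.random(), np.random.rand())
print(first == second)  # True
```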

make_time_grid(t_f=100.0, n_t=1000)

Create a uniform time grid.

PARAMETER DESCRIPTION
t_f

Final time.

TYPE: float DEFAULT: 100.0

n_t

Number of time points.

TYPE: int DEFAULT: 1000

RETURNS DESCRIPTION
ndarray

Time grid as float32 array of shape (n_t,).
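A uniform float32 grid of shape (n_t,) can be built with `numpy.linspace`. This sketch assumes the grid starts at t = 0 and includes t_f; the docs above only say "uniform", so the start point is an assumption.

```python
import numpy as np


def make_time_grid_sketch(t_f: float = 100.0, n_t: int = 1000) -> np.ndarray:
    """Uniform grid on [0, t_f] with n_t points (start at 0 is assumed)."""
    return np.linspace(0.0, t_f, n_t, dtype=np.float32)


t = make_time_grid_sketch()
print(t.shape, t.dtype, t[0], t[-1])  # (1000,) float32 0.0 100.0
```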

build_u_list(kind, *, n_inputs=None, u_values=None, dose_range=None, u_spec=None)

Construct a list of inputs for a task kind.

PARAMETER DESCRIPTION
kind

Task kind.

TYPE: str

n_inputs

Number of input channels.

TYPE: Optional[int] DEFAULT: None

u_values

Values to enumerate for grid tasks.

TYPE: Optional[List[float]] DEFAULT: None

dose_range

(u_min, u_max, n) for "dose_response" tasks.

TYPE: Optional[Tuple[float, float, int]] DEFAULT: None

u_spec

Optional escape hatch specifying exact input generation:
  • ("custom", u_list)
  • ("grid", values)
  • ("linspace", u_min, u_max, n)

TYPE: Optional[tuple] DEFAULT: None

RETURNS DESCRIPTION
List[ndarray]

List of input vectors (float32 arrays).
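The three u_spec forms could be expanded as below. This is a hypothetical sketch with two labelled assumptions: "grid" is taken to mean the Cartesian product of the values over all n_inputs channels, and "linspace" is taken to produce one-channel inputs, as in a dose-response sweep.

```python
import itertools
from typing import List

import numpy as np


def expand_u_spec(u_spec: tuple, n_inputs: int = 1) -> List[np.ndarray]:
    """Turn a u_spec tuple into a list of float32 input vectors."""
    tag = u_spec[0]
    if tag == "custom":
        rows = u_spec[1]  # explicit vectors, passed through
    elif tag == "grid":
        # assumption: every combination of the values across n_inputs channels
        rows = itertools.product(u_spec[1], repeat=n_inputs)
    elif tag == "linspace":
        _, u_min, u_max, n = u_spec
        rows = ([u] for u in np.linspace(u_min, u_max, n))  # assumption: 1 channel
    else:
        raise ValueError(f"unknown u_spec tag: {tag!r}")
    return [np.asarray(row, dtype=np.float32).reshape(-1) for row in rows]


print(len(expand_u_spec(("grid", [0.0, 1.0]), n_inputs=2)))  # 4
```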

build_ic(species_labels, ic_spec)

Build an RL4CRN IC object from a compact spec.

PARAMETER DESCRIPTION
species_labels

Species names for the CRN.

TYPE: List[str]

ic_spec

One of:
  • "zero"
  • ("constant", value)
  • ("values", values_2d)

TYPE: Union[str, tuple]

RETURNS DESCRIPTION
Any

RL4CRN IC instance.

RAISES DESCRIPTION
ValueError

If ic_spec is unknown.

build_weights(q, n_t, w_spec)

Build a weight matrix for tracking losses.

PARAMETER DESCRIPTION
q

Output dimension (usually 1).

TYPE: int

n_t

Number of time points.

TYPE: int

w_spec

One of:
  • "steady_state": weight only the last time point
  • "uniform": all ones
  • "transient": bias early/late times
  • ("custom", array_like)

TYPE: Union[str, tuple]

RETURNS DESCRIPTION
ndarray

Weight matrix of shape (q, n_t) float32.

RAISES DESCRIPTION
ValueError

If w_spec is unknown.
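The fully specified weight options can be sketched with NumPy. `build_weights_sketch` is hypothetical; it leaves out "transient" (raising NotImplementedError) because the exact early/late profile is not documented here, and it assumes a ("custom", array_like) value must broadcast to (q, n_t).

```python
import numpy as np


def build_weights_sketch(q: int, n_t: int, w_spec) -> np.ndarray:
    """Weight matrix of shape (q, n_t), float32."""
    if w_spec == "steady_state":
        w = np.zeros((q, n_t), dtype=np.float32)
        w[:, -1] = 1.0  # only the final time point counts
        return w
    if w_spec == "uniform":
        return np.ones((q, n_t), dtype=np.float32)
    if w_spec == "transient":
        raise NotImplementedError("transient profile is not specified in this sketch")
    if isinstance(w_spec, tuple) and len(w_spec) == 2 and w_spec[0] == "custom":
        w = np.asarray(w_spec[1], dtype=np.float32)
        return np.broadcast_to(w, (q, n_t)).copy()  # assumption: broadcastable input
    raise ValueError(f"unknown w_spec: {w_spec!r}")


print(build_weights_sketch(1, 5, "steady_state"))  # [[0. 0. 0. 0. 1.]]
```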

make_task(template_crn, library_components, kind, species_labels, *, params=None)

Create a TaskSpec from a params dictionary and build its reward callable.

This is the ONLY public constructor: users pass task knobs via params. Default interpretation of missing fields is delegated to the TaskKind handler.

Common params keys (shared across many tasks):
  • "t_f": float
  • "n_t": int
  • "n_inputs": int (defaults to template_crn.num_inputs)
  • "ic": Union[str, tuple], e.g. "zero" or ("constant", 0.01)
  • "weights": Union[str, tuple], e.g. "transient" or ("custom", ...)
  • "u_spec": tuple (only for special-tag generation)
  • "u_list": List[np.ndarray] (explicit scenarios)

Task-specific keys are documented by TaskKind.help().

PARAMETER DESCRIPTION
template_crn

Compiled IOCRN template.

TYPE: IOCRN

library_components

Tuple (library, M, K, masks).

TYPE: tuple[ReactionLibrary, int, int, dict[str, Any]]

kind

Task kind string.

TYPE: str

species_labels

Species labels used by the task.

TYPE: List[str]

params

Task configuration dictionary.

TYPE: Optional[Dict[str, Any]] DEFAULT: None

RETURNS DESCRIPTION
TaskSpec

TaskSpec with runtime fields (time_horizon/u_list/ic/weights/compute_reward) populated.

RAISES DESCRIPTION
ValueError

If required parameters are missing or inconsistent.

make_reward_fn_with_overrides(task, *, u_list=None, ic_spec=None, weights_spec=None, **kwargs)

Build a reward function for a TaskSpec, optionally overriding conditions.

This is the single entry point used by training, sampling, resimulation, and load.

PARAMETER DESCRIPTION
task

Base TaskSpec.

TYPE: TaskSpec

u_list

Optional replacement list of input vectors.

TYPE: Optional[List[ndarray]] DEFAULT: None

ic_spec

Optional IC spec override.

TYPE: Optional[Union[str, tuple]] DEFAULT: None

weights_spec

Optional weights spec override.

TYPE: Optional[Union[str, tuple]] DEFAULT: None

**kwargs

Additional task-kind-specific overrides.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Callable[[Any], Union[float, Tuple[float, Dict[str, Any]]]]

Reward callable accepting a CRN state and returning loss or (loss, info).

RAISES DESCRIPTION
ValueError

If task.kind is unknown or required fields are missing.

build_envs(template, max_added_reactions, batch_size, hall_of_fame_size, n_cpus, logger=None)

Create parallel environments.

PARAMETER DESCRIPTION
template

IOCRN template.

TYPE: Any

max_added_reactions

Episode length.

TYPE: int

batch_size

Number of environments.

TYPE: int

hall_of_fame_size

Hall-of-fame capacity.

TYPE: int

n_cpus

Number of CPUs for parallel execution.

TYPE: int

logger

Optional logger.

TYPE: Any DEFAULT: None

RETURNS DESCRIPTION

ParallelEnvironments instance.

build_interfaces(library, device, allow_input_influence=False)

Build standard env<->agent interfaces.

PARAMETER DESCRIPTION
library

Reaction library.

TYPE: Any

device

Torch device string.

TYPE: str

allow_input_influence

Whether to allow input influence features.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION

Tuple (observer, tensorizer, actuator, stepper).

build_policy(M, K, p, masks, device, policy_cfg, target_set_size)

Build the policy instance.

PARAMETER DESCRIPTION
M

Number of reactions in the library.

TYPE: int

K

Number of total library parameters.

TYPE: int

p

Number of input channels in CRN.

TYPE: int

masks

Parameter/logit masks from the library.

TYPE: Dict[str, Any]

device

Torch device string.

TYPE: str

policy_cfg

PolicyCfg instance.

TYPE: PolicyCfg

target_set_size

Required for ordered policy.

TYPE: int

RETURNS DESCRIPTION

Policy instance.

build_agent(policy, device, agent_cfg, logger=None)

Build the REINFORCE(+SIL) agent.

PARAMETER DESCRIPTION
policy

Policy instance.

TYPE: Any

device

Torch device string.

TYPE: str

agent_cfg

AgentCfg instance.

TYPE: AgentCfg

logger

Optional logger.

TYPE: Any DEFAULT: None

RETURNS DESCRIPTION

REINFORCEAgent instance.

make_session_and_trainer(cfg, task, device='auto', logger=None)

Convenience function to build a session and trainer.

PARAMETER DESCRIPTION
cfg

Configuration.

TYPE: Config

task

Materialized TaskSpec object.

TYPE: TaskSpec

device

Device preference ("auto", "cpu", or "cuda").

TYPE: str DEFAULT: 'auto'

RETURNS DESCRIPTION
Tuple[Session, Trainer]

Tuple (session, trainer): the built Session and a Trainer wrapping it.

print_task_summary(task, max_preview=3)

Print a compact summary of a TaskSpec.

run_smoke_reward(task, state, label='')

Call task.compute_reward on a given state and print normalized output.
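Since a reward callable may return either a bare loss or a (loss, info) pair, "normalized output" presumably means mapping both forms to one shape. A hypothetical sketch of that normalization (`normalize_reward_output` is not an RL4CRN function):

```python
from typing import Any, Dict, Tuple, Union

RewardOut = Union[float, Tuple[float, Dict[str, Any]]]


def normalize_reward_output(out: RewardOut) -> Tuple[float, Dict[str, Any]]:
    """Always return (loss, info); a bare loss gets an empty info dict."""
    if isinstance(out, tuple):
        loss, info = out
        return float(loss), dict(info)
    return float(out), {}


print(normalize_reward_output(0.5))  # (0.5, {})
```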

load_session_and_trainer(checkpoint_path, *, task=None, device='auto', strict=True)

Load a checkpoint and reconstruct a working Trainer.

This convenience function rebuilds the Session/Trainer wiring from scratch and then applies checkpoint state (policy weights, training state, HoFs, RNG state). It also rebuilds runtime-only callables (e.g., task.compute_reward).

Notes
  • If task is provided, it is used as the task definition, and the checkpoint's policy weights and training state are loaded into the session rebuilt around it.
  • If task is None, this function expects the checkpoint's config to contain a serializable TaskSpec under config['task'].
PARAMETER DESCRIPTION
checkpoint_path

Path to the checkpoint file created by Trainer.save.

TYPE: str

task

Optional TaskSpec to use instead of the checkpoint's saved task.

TYPE: Optional[TaskSpec] DEFAULT: None

device

Device preference ("auto", "cpu", "cuda").

TYPE: str DEFAULT: 'auto'

strict

Whether to strictly enforce key matching in load_state_dict.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION

Trainer object fully reconstructed and ready to use.

RAISES DESCRIPTION
FileNotFoundError

If checkpoint_path does not exist.

KeyError

If required keys are missing and task is not provided.

ValueError

If task reconstruction fails.

register_task_kind(cls)

Register a TaskKindBase subclass into the global registry.

PARAMETER DESCRIPTION
cls

TaskKind class.

TYPE: type[TaskKindBase]

RETURNS DESCRIPTION
type[TaskKindBase]

The same class for decorator usage.

RAISES DESCRIPTION
ValueError

If the class does not define 'kind', or if its kind is already registered.

get_task_kind(kind)

Instantiate a task-kind handler by name.

PARAMETER DESCRIPTION
kind

Task kind string.

TYPE: str

RETURNS DESCRIPTION
TaskKindBase

Instance of a TaskKindBase subclass.

RAISES DESCRIPTION
ValueError

If kind is unknown.
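register_task_kind and get_task_kind follow the usual decorator-registry pattern. The sketch below is hypothetical throughout (`TaskKindSketch`, `register`, `get_kind`, and the kind name "constant_demo` are illustrative), and it simplifies default_u_list to take n_inputs instead of a TaskSpec.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Type

import numpy as np


class TaskKindSketch(ABC):
    """Hypothetical minimal base: subclasses name a kind and supply default inputs."""

    kind: str = ""

    @abstractmethod
    def default_u_list(self, n_inputs: int) -> List[np.ndarray]:
        """Simplified signature; the documented method receives a TaskSpec."""


_REGISTRY: Dict[str, Type[TaskKindSketch]] = {}


def register(cls: Type[TaskKindSketch]) -> Type[TaskKindSketch]:
    """Decorator: store cls under cls.kind and return it unchanged."""
    kind = getattr(cls, "kind", "")
    if not kind:
        raise ValueError(f"{cls.__name__} does not define 'kind'")
    if kind in _REGISTRY:
        raise ValueError(f"task kind {kind!r} is already registered")
    _REGISTRY[kind] = cls
    return cls


def get_kind(kind: str) -> TaskKindSketch:
    """Instantiate the handler registered under `kind`."""
    if kind not in _REGISTRY:
        raise ValueError(f"unknown task kind {kind!r}; known: {sorted(_REGISTRY)}")
    return _REGISTRY[kind]()


@register
class ConstantInputKind(TaskKindSketch):
    kind = "constant_demo"  # hypothetical kind, not one shipped by RL4CRN

    def default_u_list(self, n_inputs: int) -> List[np.ndarray]:
        return [np.ones(n_inputs, dtype=np.float32)]


handler = get_kind("constant_demo")
print(handler.default_u_list(2))
```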

overrides_get(task, overrides, key, *, fallback_attr=None, default=None)

Resolve a parameter with precedence overrides > task.params > the task attribute named by fallback_attr.

PARAMETER DESCRIPTION
task

TaskSpec instance.

TYPE: TaskSpec

overrides

Override dictionary.

TYPE: Dict[str, Any]

key

Key to search in overrides/task.params.

TYPE: str

fallback_attr

If provided, also look up the TaskSpec attribute with this name (lowest precedence).

TYPE: Optional[str] DEFAULT: None

default

Default if not found.

TYPE: Any DEFAULT: None

RETURNS DESCRIPTION
Any

Resolved value or default.
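The three-level lookup can be sketched with a stand-in task object. `overrides_get_sketch` is hypothetical, and treating a None value as "not set" at every level is an assumption of this sketch.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def overrides_get_sketch(task: Any, overrides: Dict[str, Any], key: str, *,
                         fallback_attr: Optional[str] = None,
                         default: Any = None) -> Any:
    """Look up key in overrides, then task.params, then task.<fallback_attr>."""
    if overrides.get(key) is not None:
        return overrides[key]
    params = getattr(task, "params", None) or {}
    if params.get(key) is not None:
        return params[key]
    if fallback_attr is not None:
        value = getattr(task, fallback_attr, None)
        if value is not None:
            return value
    return default


task = SimpleNamespace(params={"t_f": 50.0}, n_t=200)  # stand-in for a TaskSpec
print(overrides_get_sketch(task, {"t_f": 10.0}, "t_f"))  # 10.0
```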