Skip to content

hremd

Perform Hamiltonian replica exchange sampling.

Functions:

  • run_hremd

    Run a Hamiltonian replica exchange simulation.

run_hremd

run_hremd(
    simulation: Simulation,
    states: list[dict[str, float]],
    config: HREMD,
    output_dir: Path,
    swap_mask: set[tuple[int, int]] | None = None,
    force_groups: set[int] | int = -1,
    initial_coords: list[State] | None = None,
    analysis_fn: Callable[[int, ndarray, ndarray], None]
    | None = None,
    analysis_interval: int | None = None,
)

Run a Hamiltonian replica exchange simulation.

Parameters:

  • simulation (Simulation) –

    The main simulation object to sample using.

  • states (list[dict[str, float]]) –

    The states to sample at. Each state should be a dictionary with keys corresponding to global context parameters.

  • config (HREMD) –

    The sampling configuration.

  • output_dir (Path) –

    The directory in which to store the sampled energies and statistics, as well as any trajectory files if requested in the config.

  • swap_mask (set[tuple[int, int]] | None, default: None ) –

    Pairs of states that should not be swapped.

  • force_groups (set[int] | int, default: -1 ) –

    The force groups to consider when computing the reduced potentials.

  • initial_coords (list[State] | None, default: None ) –

    The initial coordinates of each state. If not provided, the coordinates will be taken from the simulation object.

  • analysis_fn (Callable[[int, ndarray, ndarray], None] | None, default: None ) –

    A function to call after every analysis_interval cycles. It should take as arguments the current cycle number, the reduced potentials with shape=(n_states, n_samples) and the number of samples of each state with shape=(n_states,).

  • analysis_interval (int | None, default: None ) –

    The interval with which to call the analysis function. If None, no analysis will be performed.

Source code in femto/md/hremd.py
def run_hremd(
    simulation: openmm.app.Simulation,
    states: list[dict[str, float]],
    config: femto.md.config.HREMD,
    output_dir: pathlib.Path,
    swap_mask: set[tuple[int, int]] | None = None,
    force_groups: set[int] | int = -1,
    initial_coords: list[openmm.State] | None = None,
    analysis_fn: typing.Callable[[int, numpy.ndarray, numpy.ndarray], None]
    | None = None,
    analysis_interval: int | None = None,
):
    """Run a Hamiltonian replica exchange simulation.

    Args:
        simulation: The main simulation object to sample using.
        states: The states to sample at. Each state should be a dictionary with keys
            corresponding to global context parameters.
        config: The sampling configuration.
        output_dir: The directory in which to store the sampled energies and
            statistics, as well as any trajectory files if requested in the config.
        swap_mask: Pairs of states that should not be swapped.
        force_groups: The force groups to consider when computing the reduced
            potentials.
        initial_coords: The initial coordinates of each state. If not provided, the
            coordinates will be taken from the simulation object. If provided, it
            must contain exactly one entry per state.
        analysis_fn: A function to call after every ``analysis_interval`` cycles
            (starting with cycle 0). It should take as arguments the current cycle
            number, the reduced potentials with ``shape=(n_states, n_samples)`` and
            the number of samples of each state with ``shape=(n_states,)``.
        analysis_interval: The interval with which to call the analysis function.
            If ``None``, no analysis will be performed.

    Returns:
        A tuple ``(u_kn, n_k)``. ``u_kn`` holds the reduced potentials with
        ``shape=(n_states, n_states * config.n_cycles)``, and ``n_k`` holds the
        number of samples of each state with ``shape=(n_states,)``. The potentials
        are only gathered on MPI rank 0, so ``u_kn`` should not be relied upon on
        any other rank.

    Raises:
        ValueError: If the system contains more than one ``MonteCarloBarostat``, or
            if ``initial_coords`` does not contain exactly one entry per state.
    """
    from mpi4py import MPI

    n_states = len(states)

    # Validate on every rank before entering any MPI section. Otherwise only some
    # ranks would fail with an ``IndexError`` while the others carry on.
    if initial_coords is not None and len(initial_coords) != n_states:
        raise ValueError(
            f"expected {n_states} initial coordinates (one per state), "
            f"got {len(initial_coords)}"
        )

    states = [
        femto.md.utils.openmm.evaluate_ctx_parameters(state, simulation.system)
        for state in states
    ]

    swap_mask = set() if swap_mask is None else swap_mask

    n_proposed_swaps = numpy.zeros((n_states, n_states))
    n_accepted_swaps = numpy.zeros((n_states, n_states))

    replica_to_state_idx = numpy.arange(n_states)

    # The sample of state ``k`` at cycle ``c`` is stored in column
    # ``k * n_cycles + c`` of ``u_kn``.
    u_kn, n_k = (
        numpy.empty((n_states, n_states * config.n_cycles)),
        numpy.zeros(n_states, dtype=int),
    )
    # Marks which columns of ``u_kn`` have been filled so far, so that only
    # populated columns are handed to ``analysis_fn``.
    has_sampled = numpy.zeros(n_states * config.n_cycles, bool)

    barostats = [
        force
        for force in simulation.system.getForces()
        if isinstance(force, openmm.MonteCarloBarostat)
    ]
    # Validate explicitly rather than with ``assert``, which ``python -O`` strips.
    if len(barostats) > 1:
        raise ValueError(
            f"expected at most one MonteCarloBarostat, found {len(barostats)}"
        )

    # A barostat with a non-positive frequency is treated as if there were none.
    pressure = (
        None
        if len(barostats) == 0 or barostats[0].getFrequency() <= 0
        else barostats[0].getDefaultPressure()
    )

    samples_path = output_dir / "samples.arrow"

    with (
        femto.md.utils.mpi.get_mpi_comm() as mpi_comm,
        _create_storage(mpi_comm, samples_path, n_states) as storage,
        contextlib.ExitStack() as exit_stack,
    ):
        # Each MPI process may be responsible for propagating multiple states,
        # e.g. if we have 20 states to simulate windows but only 4 GPUs to run on.
        n_replicas, replica_idx_offset = femto.md.utils.mpi.divide_tasks(
            mpi_comm, n_states
        )

        if initial_coords is None:
            # Every local replica starts from the same snapshot of the simulation.
            coords = [simulation.context.getState(getPositions=True)] * n_replicas
        else:
            coords = [initial_coords[i + replica_idx_offset] for i in range(n_replicas)]

        if mpi_comm.rank == 0:
            _LOGGER.info("running %s warm-up steps", config.n_warmup_steps)

        _propagate_replicas(
            simulation,
            config.temperature,
            pressure,
            states,
            coords,
            config.n_warmup_steps,
            replica_to_state_idx,
            replica_idx_offset,
            force_groups,
            config.max_step_retries,
        )

        if mpi_comm.rank == 0:
            _LOGGER.info("running %s replica exchange cycles", config.n_cycles)

        trajectory_storage = _create_trajectory_storage(
            simulation,
            n_replicas,
            replica_idx_offset,
            config.n_steps_per_cycle,
            config.trajectory_interval,
            output_dir,
            exit_stack,
        )

        for cycle in tqdm.tqdm(
            range(config.n_cycles), total=config.n_cycles, disable=mpi_comm.rank != 0
        ):
            reduced_potentials = _propagate_replicas(
                simulation,
                config.temperature,
                pressure,
                states,
                coords,
                config.n_steps_per_cycle,
                replica_to_state_idx,
                replica_idx_offset,
                force_groups,
                config.max_step_retries,
            )
            # ``reduce`` only returns the summed result on the root rank and
            # ``None`` elsewhere. The values written into ``u_kn`` are therefore
            # only meaningful on rank 0.
            reduced_potentials = mpi_comm.reduce(reduced_potentials, MPI.SUM, 0)

            has_sampled[replica_to_state_idx * config.n_cycles + cycle] = True
            u_kn[:, replica_to_state_idx * config.n_cycles + cycle] = reduced_potentials

            # Each state gains one sample per cycle. This presumes
            # ``replica_to_state_idx`` stays a permutation of the states.
            n_k += 1

            should_save_trajectory = (
                config.trajectory_interval is not None
                and cycle % config.trajectory_interval == 0
            )

            if should_save_trajectory:
                _store_trajectory(coords, trajectory_storage)

            should_analyze = (
                analysis_fn is not None
                and analysis_interval is not None
                and cycle % analysis_interval == 0
            )

            if should_analyze:
                # NOTE(review): this runs on every rank, including those where
                # ``u_kn`` does not hold the gathered potentials. Confirm that
                # callers' ``analysis_fn`` tolerates this.
                analysis_fn(cycle, u_kn[:, has_sampled], n_k)

            if mpi_comm.rank == 0:
                _store_potentials(
                    replica_to_state_idx,
                    reduced_potentials,
                    n_proposed_swaps,
                    n_accepted_swaps,
                    storage,
                    cycle * config.n_steps_per_cycle,
                )
                _propose_swaps(
                    replica_to_state_idx,
                    reduced_potentials,
                    n_proposed_swaps,
                    n_accepted_swaps,
                    swap_mask,
                    config.swap_mode,
                    config.max_swaps,
                )

            # Swaps are only decided on rank 0. Broadcast the updated mapping so
            # that every rank propagates the same state assignment next cycle.
            replica_to_state_idx = mpi_comm.bcast(replica_to_state_idx, 0)

        mpi_comm.barrier()

    return u_kn, n_k