Skip to content

hremd #

Perform Hamiltonian replica exchange sampling.

Functions:

  • run_hremd

    Run a Hamiltonian replica exchange simulation.

run_hremd #

run_hremd(
    simulation: Simulation,
    states: list[dict[str, float]],
    config: HREMD,
    output_dir: Path | None = None,
    swap_mask: set[tuple[int, int]] | None = None,
    force_groups: set[int] | int = -1,
    initial_coords: list[State] | None = None,
    analysis_fn: (
        Callable[[int, ndarray, ndarray], None] | None
    ) = None,
    analysis_interval: int | None = None,
) -> tuple[ndarray, ndarray, list[State]]

Run a Hamiltonian replica exchange simulation.

Parameters:

  • simulation (Simulation) –

    The main simulation object to sample using.

  • states (list[dict[str, float]]) –

    The states to sample at. Each state should be a dictionary with keys corresponding to global context parameters.

  • config (HREMD) –

    The sampling configuration.

  • output_dir (Path | None, default: None ) –

    The directory to store the sampled energies and statistics to, and any trajectory / checkpoint files if requested in the config. If None, no output of any kind will be written.

  • swap_mask (set[tuple[int, int]] | None, default: None ) –

    Pairs of states that should not be swapped.

  • force_groups (set[int] | int, default: -1 ) –

    The force groups to consider when computing the reduced potentials.

  • initial_coords (list[State] | None, default: None ) –

    The initial coordinates of each state. If not provided, the coordinates will be taken from the simulation object.

  • analysis_fn (Callable[[int, ndarray, ndarray], None] | None, default: None ) –

    A function to call after every analysis_interval cycles. It should take as arguments the current cycle number, the reduced potentials with shape=(n_states, n_samples) and the number of samples of each state with shape=(n_states,).

  • analysis_interval (int | None, default: None ) –

    The interval with which to call the analysis function. If None, no analysis will be performed.

Returns:

  • tuple[ndarray, ndarray, list[State]]

    The reduced potentials, the number of samples of each state, and the final coordinates of each state.

Source code in femto/md/hremd.py
def run_hremd(
    simulation: openmm.app.Simulation,
    states: list[dict[str, float]],
    config: femto.md.config.HREMD,
    output_dir: pathlib.Path | None = None,
    swap_mask: set[tuple[int, int]] | None = None,
    force_groups: set[int] | int = -1,
    initial_coords: list[openmm.State] | None = None,
    analysis_fn: typing.Callable[[int, numpy.ndarray, numpy.ndarray], None]
    | None = None,
    analysis_interval: int | None = None,
) -> tuple[numpy.ndarray, numpy.ndarray, list[openmm.State]]:
    """Run a Hamiltonian replica exchange simulation.

    Args:
        simulation: The main simulation object to sample using.
        states: The states to sample at. Each state should be a dictionary with
            keys corresponding to global context parameters.
        config: The sampling configuration.
        output_dir: The directory to store the sampled energies and statistics to, and
            any trajectory / checkpoint files if requested in the config. If ``None``,
            no output of any kind will be written.
        swap_mask: Pairs of states that should not be swapped.
        force_groups: The force groups to consider when computing the reduced
            potentials.
        initial_coords: The initial coordinates of each state. If not provided, the
            coordinates will be taken from the simulation object.
        analysis_fn: A function to call after every ``analysis_interval`` cycles. It
            should take as arguments the current cycle number, the reduced potentials
            with ``shape=(n_states, n_samples)`` and the number of samples of each
            state with ``shape=(n_states,)``. It is only called on MPI rank 0.
        analysis_interval: The interval with which to call the analysis function.
            If ``None`` or not positive, no analysis will be performed.

    Returns:
        The reduced potentials, the number of samples of each state, and the final
        coordinates of each state, ordered so that entry ``i`` holds the coordinates
        of the replica that finished the run at state ``i``.
    """
    from mpi4py import MPI

    n_states = len(states)

    states = [
        femto.md.utils.openmm.evaluate_ctx_parameters(state, simulation.system)
        for state in states
    ]

    swap_mask = set() if swap_mask is None else swap_mask

    n_proposed_swaps = numpy.zeros((n_states, n_states))
    n_accepted_swaps = numpy.zeros((n_states, n_states))

    # replica ``r`` is currently simulated at state ``replica_to_state_idx[r]``.
    replica_to_state_idx = numpy.arange(n_states)

    # each cycle contributes one sample per state, so columns are laid out in
    # per-state blocks of ``n_cycles`` samples.
    u_kn, n_k = (
        numpy.empty((n_states, n_states * config.n_cycles)),
        numpy.zeros(n_states, dtype=int),
    )
    has_sampled = numpy.zeros(n_states * config.n_cycles, bool)

    pressure = femto.md.utils.openmm.get_pressure(simulation.system)

    samples_path = None if output_dir is None else output_dir / "samples.arrow"
    checkpoint_path = (
        None
        if output_dir is None
        or config.checkpoint_interval is None
        or config.checkpoint_interval <= 0
        else output_dir / "checkpoint.pkl"
    )

    # mirror the checkpoint handling above: a non-positive interval disables the
    # feature instead of raising ``ZeroDivisionError`` on rank 0 only (which could
    # leave the other ranks waiting in a collective).
    run_analysis = (
        analysis_fn is not None
        and analysis_interval is not None
        and analysis_interval > 0
    )

    start_cycle = 0

    if checkpoint_path is not None and checkpoint_path.exists():
        (
            start_cycle,
            initial_coords,
            u_kn,
            n_k,
            has_sampled,
            n_proposed_swaps,
            n_accepted_swaps,
            replica_to_state_idx,
        ) = _load_checkpoint(config, n_states, checkpoint_path)
        _LOGGER.info(f"resuming from cycle {start_cycle} samples")

    with (
        femto.md.utils.mpi.get_mpi_comm() as mpi_comm,
        _create_storage(mpi_comm, samples_path, n_states, start_cycle) as storage,
        contextlib.ExitStack() as exit_stack,
    ):
        # each MPI process may be responsible for propagating multiple states,
        # e.g. if we have 20 states to simulate windows but only 4 GPUs to run on.
        n_replicas, replica_idx_offset = femto.md.utils.mpi.divide_tasks(
            mpi_comm, n_states
        )

        if initial_coords is None:
            coords_0 = simulation.context.getState(
                getPositions=True, enforcePeriodicBox=config.trajectory_enforce_pbc
            )
            coords = [coords_0] * n_replicas
        else:
            # ``coords`` holds only the replicas owned by this rank, indexed locally.
            coords = [initial_coords[i + replica_idx_offset] for i in range(n_replicas)]

        if start_cycle == 0:
            if mpi_comm.rank == 0:
                _LOGGER.info(f"running {config.n_warmup_steps} warm-up steps")

            _propagate_replicas(
                simulation,
                config.temperature,
                pressure,
                states,
                coords,
                config.n_warmup_steps,
                replica_to_state_idx,
                replica_idx_offset,
                force_groups,
                config.max_step_retries,
                config.trajectory_enforce_pbc,
            )

            if mpi_comm.rank == 0:
                _LOGGER.info(f"running {config.n_cycles} replica exchange cycles")

        trajectory_storage = _create_trajectory_storage(
            simulation,
            n_replicas,
            replica_idx_offset,
            config.n_steps_per_cycle,
            start_cycle,
            config.trajectory_interval,
            output_dir,
            exit_stack,
        )

        for cycle in tqdm.tqdm(
            range(start_cycle, config.n_cycles),
            total=config.n_cycles - start_cycle,
            initial=start_cycle,
            disable=mpi_comm.rank != 0,
        ):
            reduced_potentials = _propagate_replicas(
                simulation,
                config.temperature,
                pressure,
                states,
                coords,
                config.n_steps_per_cycle,
                replica_to_state_idx,
                replica_idx_offset,
                force_groups,
                config.max_step_retries,
                config.trajectory_enforce_pbc,
            )
            # NOTE(review): mpi4py's lowercase ``reduce`` returns ``None`` on every
            # rank except the root, so ``u_kn`` is presumably only meaningful on
            # rank 0 — confirm before relying on it elsewhere.
            reduced_potentials = mpi_comm.reduce(reduced_potentials, MPI.SUM, 0)

            has_sampled[replica_to_state_idx * config.n_cycles + cycle] = True
            u_kn[:, replica_to_state_idx * config.n_cycles + cycle] = reduced_potentials

            n_k += 1

            should_save_trajectory = (
                config.trajectory_interval is not None
                and cycle % config.trajectory_interval == 0
            )

            if should_save_trajectory:
                mpi_comm.barrier()
                _store_trajectory(coords, trajectory_storage)

            should_analyze = (
                run_analysis
                and mpi_comm.rank == 0
                and cycle % analysis_interval == 0
            )

            if should_analyze:
                analysis_fn(cycle, u_kn[:, has_sampled], n_k)

            _store_potentials(
                replica_to_state_idx,
                reduced_potentials,
                n_proposed_swaps,
                n_accepted_swaps,
                storage,
                cycle * config.n_steps_per_cycle,
            )
            _propose_swaps(
                replica_to_state_idx,
                reduced_potentials,
                n_proposed_swaps,
                n_accepted_swaps,
                swap_mask,
                config.swap_mode,
                config.max_swaps,
            )

            # make every rank agree on the replica -> state assignment chosen on
            # rank 0 before the next round of propagation.
            replica_to_state_idx = mpi_comm.bcast(replica_to_state_idx, 0)

            should_checkpoint = (
                checkpoint_path is not None
                and config.checkpoint_interval is not None
                and (
                    cycle % config.checkpoint_interval == 0
                    or cycle == config.n_cycles - 1
                )
            )

            if should_checkpoint:
                # store ``cycle + 1`` as this cycle is already complete, so a
                # resumed run starts from the next one.
                _store_checkpoint(
                    cycle + 1,
                    coords,
                    u_kn,
                    n_k,
                    has_sampled,
                    n_proposed_swaps,
                    n_accepted_swaps,
                    replica_to_state_idx,
                    replica_idx_offset,
                    checkpoint_path,
                    mpi_comm,
                )

        mpi_comm.barrier()

        coords_dict = {i + replica_idx_offset: coord for i, coord in enumerate(coords)}
        coords_dict = femto.md.utils.mpi.reduce_dict(coords_dict, mpi_comm, root=None)

        # ``coords_dict`` is keyed by *replica* while ``replica_to_state_idx`` maps
        # replica -> state, so invert the mapping to return coordinates in *state*
        # order. Indexing ``coords_dict`` directly with a state index only works
        # when the final permutation happens to be its own inverse.
        state_to_replica_idx = {
            int(state_idx): replica_idx
            for replica_idx, state_idx in enumerate(replica_to_state_idx)
        }
        final_coords = [coords_dict[state_to_replica_idx[i]] for i in range(n_states)]

    return u_kn, n_k, final_coords