"""
Kernel manager that wraps jupyter_client's AsyncKernelManager
"""
import asyncio
import logging
import queue
import sys
import uuid
import socketio
from jupyter_client import AsyncKernelManager
from ..state.manager import StateManager
# Module-level logger for this manager (e.g. debug dumps of raw kernel messages)
logger = logging.getLogger(__name__)
# pylint: disable=too-many-instance-attributes
class KernelManager:
    """Kernel Manager

    This class is responsible for managing the AsyncKernelManager (async kernel)
    that runs all the MDAnalysis code. It takes care of starting the async
    kernel, stopping it and communicating with it. It interfaces with the
    CommHandler on the kernel side for messaging.

    Parameters
    ----------
    sm: :class:`~mdadash.backend.state.manager.StateManager`
        Instance of the state manager
    sio: :class:`socketio.AsyncServer`
        Instance of the socket.io server
    """

    # Keys copied verbatim from the kernel's ``tsdata`` into the "energies"
    # section of the ``timestepInfo`` event. They are emitted as-is, so
    # renaming one (even the "walls" spelling) changes the emitted payload.
    _ENERGY_KEYS = (
        "temperature",
        "total_energy",
        "potential_energy",
        "van_der_walls_energy",
        "coulomb_energy",
        "bonds_energy",
        "angles_energy",
        "dihedrals_energy",
        "improper_dihedrals_energy",
    )

    def __init__(self, sm: StateManager, sio: socketio.AsyncServer):
        self.sm = sm
        self.sio = sio
        self.km = AsyncKernelManager(kernel_name="python3")
        # kernel client: created in start(), released in stop()
        self.kc = None
        # request msg_id -> asyncio.Future resolved by the iopub listener
        self._pending_futures = {}
        # the listener loops keep polling only while this flag is set
        self._is_running = False
        # id of the single comm opened with the kernel's comm handler
        self.comm_id = uuid.uuid4().hex
        # task running both channel listeners; awaited in stop()
        self.listen_task = None

    async def start(self) -> None:
        """Start the async kernel

        Starts the kernel and a client, spawns the iopub/shell listeners,
        imports the kernel-side core module, opens the comm with the kernel
        and tells it how many universes to initialize.
        """
        # start the kernel
        await self.km.start_kernel()
        # create a client
        self.kc = self.km.client()
        self.kc.start_channels()
        await self.kc.wait_for_ready()
        # mark as running *before* spawning the listeners: they only loop while
        # this flag is set, so setting it afterwards would depend on the task
        # not being scheduled before the flag flips
        self._is_running = True
        # create task to listen on iopub and shell channels
        self.listen_task = asyncio.create_task(self._start_listening())
        # initialize the kernel core
        self.kc.execute("from mdadash.backend.kernel import core")
        # open comms with the kernel
        self._comm_open()
        # initialize n universes in kernel universe manager
        await self.send_message(
            "init_n_universes", {"n": len(self.sm.universe_configs)}
        )

    async def stop(self) -> None:
        """Stop the async kernel

        Stops the channel listeners, closes the client channels and shuts the
        kernel down gracefully. Each step is skipped when its resource was
        never created (e.g. :meth:`start` was not called or failed part way),
        and the kernel is still shut down if the listener task raised.
        """
        self._is_running = False
        try:
            if self.listen_task is not None:
                # wait for listen task to completely exit; the listeners
                # re-check _is_running after every get_msg (0.1 s timeout)
                await self.listen_task
        finally:
            self.listen_task = None
            if self.kc is not None:
                self.kc.stop_channels()
                self.kc = None
            if self.km.has_kernel:
                # shutdown kernel gracefully
                await self.km.shutdown_kernel(now=False)
                await self.km.cleanup_resources()

    async def _start_listening(self):
        """Internal: Run the iopub and shell listeners concurrently"""
        await asyncio.gather(
            self._listen_iopub_channel(),
            self._listen_shell_channel(),
        )

    async def _emit_tsdata(self, tsinfo):
        """Internal: Emit timestep data as a ``timestepInfo`` socket.io event

        ``done`` is the percentage of ``total_steps`` (taken from the first
        universe config) reached by ``step``; it is None when either value is
        unavailable or ``total_steps`` is zero.
        """
        tsdata = tsinfo["tsdata"]
        step = tsdata.get("step", None)
        configs = self.sm.universe_configs
        total_steps = configs[0].get("total_steps", None) if configs else None
        # step 0 is a real value (0% done), so test against None rather than
        # truthiness; a zero/missing total_steps cannot be used as a divisor
        if step is not None and total_steps:
            done = (step / total_steps) * 100
        else:
            done = None
        timestep_info = {
            "frame": tsinfo.get("frame", None),
            "time": tsdata.get("time", None),
            "step": step,
            "done": done,
            "energies": {key: tsdata.get(key, None) for key in self._ENERGY_KEYS},
        }
        await self.sio.emit("timestepInfo", timestep_info)

    async def _listen_iopub_channel(self):
        """Internal: Listen on iopub channel until :meth:`stop` is called

        A failure while handling one message is logged and the loop carries
        on; otherwise a single bad message would end the listener and every
        later request would time out.
        """
        while self._is_running:
            try:
                msg = await self.kc.iopub_channel.get_msg(timeout=0.1)
            except (asyncio.TimeoutError, queue.Empty):
                continue
            try:
                await self._handle_iopub_msg(msg)
            except Exception:  # pylint: disable=broad-except
                logger.exception("Failed to handle IOPUB message: %s", msg)

    # pylint: disable=too-many-branches
    async def _handle_iopub_msg(self, msg: dict) -> None:
        """Internal: Handle a single iopub message

        Timestep data is forwarded to socket.io, kernel stdout/stderr and
        errors are echoed on this server, and the pending future (if any) of
        the request that caused ``msg`` is resolved with its result.
        """
        msg_type = msg["header"]["msg_type"]
        content = msg["content"]
        parent = msg.get("parent_header") or {}
        # still-open future of the request that produced this message, if any
        future = self._pending_futures.get(parent.get("msg_id"))
        if future is not None and future.done():
            future = None
        # comm requests (send_message_await_response) expect the handler's
        # dict reply; output printed while the handler runs must not be taken
        # as that reply
        from_comm = parent.get("msg_type") == "comm_msg"

        if msg_type == "comm_msg":
            data = content["data"]
            if "tsinfo" in data:
                await self._emit_tsdata(data["tsinfo"])
            elif future is not None:
                future.set_result(data)
        elif msg_type == "stream":
            if future is not None and not from_comm:
                future.set_result(content["text"])
                return
            # redirect kernel stdout and stderr to this server output
            if content["name"] in ("stdout", "stderr"):
                output = content["text"]
                target = sys.stdout if content["name"] == "stdout" else sys.stderr
                print(
                    f"KERNEL ({content['name']}): {output}", end="", file=target
                )
        elif msg_type == "execute_result":
            # value of the last expression of executed code (e.g. ``1 + 1``)
            if future is not None and not from_comm:
                future.set_result(content["data"].get("text/plain", ""))
            else:
                logger.debug("IOPUB: %s", msg)
        elif msg_type == "error":
            # redirect kernel errors to server output
            print(f"KERNEL (error): {content['ename']}: {content['evalue']}")
            if future is not None:
                if from_comm:
                    # keep the documented response shape for comm requests
                    future.set_result(
                        {
                            "status": "error",
                            "message": f"{content['ename']}: {content['evalue']}",
                        }
                    )
                else:
                    future.set_result(content["evalue"])
        else:
            logger.debug("IOPUB: %s", msg)
            # TODO: handle other message types

    async def _listen_shell_channel(self):
        """Internal: Listen on shell channel until :meth:`stop` is called

        Shell replies are currently only logged at debug level.
        """
        while self._is_running:
            try:
                msg = await self.kc.shell_channel.get_msg(timeout=0.1)
            except (asyncio.TimeoutError, queue.Empty):
                continue
            logger.debug("SHELL: %s", msg)

    def _comm_open(self):
        """Internal: Open comms with the kernel's ``kernel_comm_handler``"""
        content = {
            "comm_id": self.comm_id,
            "target_name": "kernel_comm_handler",
            "data": {"msg_type": "handshake"},
        }
        open_msg = self.kc.session.msg("comm_open", content=content)
        self.kc.shell_channel.send(open_msg)

    def _make_comm_msg(self, msg_type: str, data) -> dict:
        """Internal: Build a ``comm_msg`` addressed to the kernel comm handler"""
        content = {
            "comm_id": self.comm_id,
            "target_name": "kernel_comm_handler",
            "data": {"msg_type": msg_type, "data": data},
        }
        return self.kc.session.msg("comm_msg", content=content)

    async def send_message(self, msg_type: str, data: dict) -> None:
        """Send message to kernel and don't await a response

        Parameters
        ----------
        msg_type: str
            A message type string that the kernel has a handler registered for
        data: dict
            Dict that gets passed to the handler in the kernel
        """
        self.kc.shell_channel.send(self._make_comm_msg(msg_type, data))

    async def send_message_await_response(
        self, msg_type: str, data: dict | None = None, timeout: int = 5
    ) -> dict | None:
        """Send message to kernel and wait for a response (async)

        Parameters
        ----------
        msg_type: str
            A message type string that the kernel has a handler registered for
        data: dict
            Dict that gets passed to the handler in the kernel (default: None)
        timeout: int
            Timeout in seconds (default: 5)

        Returns
        -------
        response: dict
            Response dict indicating status. This has the following keys:
            status
                String indication status: 'ok' or 'error'
            message
                An error message string when status is 'error'
            If the kernel reports an error for this request, an error dict is
            built from the exception name and value.

        Raises
        ------
        TimeoutError
            If no response arrives within ``timeout`` seconds
        """
        data_msg = self._make_comm_msg(msg_type, data)
        msg_id = data_msg["header"]["msg_id"]
        # register the future before sending so that the iopub listener can
        # resolve it as soon as the response arrives
        future = asyncio.get_running_loop().create_future()
        self._pending_futures[msg_id] = future
        self.kc.shell_channel.send(data_msg)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as e:  # pragma: no cover
            raise TimeoutError("Timed out waiting for kernel response") from e
        finally:
            self._pending_futures.pop(msg_id, None)

    async def execute_code(self, code: str, timeout: int = 5) -> str:
        """Execute code in the kernel

        Parameters
        ----------
        code: str
            Code to execute in the kernel
        timeout: int
            Timeout in seconds (default: 5)

        Returns
        -------
        response: str
            A string representation of the output of the code executed:
            whichever arrives first of a chunk of stdout/stderr text, the
            plain-text result of the last expression, or the exception value
            if the code raised

        Raises
        ------
        TimeoutError
            If none of the above arrives within ``timeout`` seconds (e.g. the
            code produces no output at all)
        """
        msg_id = self.kc.execute(code)
        # registering after execute() is safe: there is no await in between,
        # so the iopub listener cannot process the reply before this runs
        future = asyncio.get_running_loop().create_future()
        self._pending_futures[msg_id] = future
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as e:  # pragma: no cover
            raise TimeoutError("Timed out waiting for kernel execute response") from e
        finally:
            self._pending_futures.pop(msg_id, None)

    async def connect_to_simulations(self) -> dict:
        """Connect to the MD simulation

        On success the state manager is marked connected and not running.

        Returns
        -------
        response: dict
            Response dict indicating status. This has the following keys:
            status
                String indication status: 'ok' or 'error'
            message
                An error message string when status is 'error'
        """
        response = await self.send_message_await_response(
            "connect_to_simulations", self.sm.universe_configs
        )
        if response.get("status") == "ok":
            self.sm.running_state["connected"] = True
            self.sm.running_state["running"] = False
        return response

    async def disconnect_from_simulations(self) -> dict:
        """Disconnect from the MD simulation

        On success the state manager is marked disconnected.

        Returns
        -------
        response: dict
            Response dict indicating status. This has the following keys:
            status
                String indication status: 'ok' or 'error'
            message
                An error message string when status is 'error'
        """
        response = await self.send_message_await_response(
            "disconnect_from_simulations", {}
        )
        if response.get("status") == "ok":
            self.sm.running_state["connected"] = False
        return response

    async def pause_simulations(self) -> dict:
        """Pause MD simulations

        On success the state manager is marked not running.

        Returns
        -------
        response: dict
            Response dict indicating status. This has the following keys:
            status
                String indication status: 'ok' or 'error'
            message
                An error message string when status is 'error'
        """
        response = await self.send_message_await_response("pause_simulations", {})
        if response.get("status") == "ok":
            self.sm.running_state["running"] = False
        return response

    async def resume_simulations(self) -> dict:
        """Resume MD simulations

        On success the state manager is marked running.

        Returns
        -------
        response: dict
            Response dict indicating status. This has the following keys:
            status
                String indication status: 'ok' or 'error'
            message
                An error message string when status is 'error'
        """
        response = await self.send_message_await_response("resume_simulations", {})
        if response.get("status") == "ok":
            self.sm.running_state["running"] = True
        return response