# Source code for pymt5.transport

import asyncio
import contextlib
import enum
import inspect as _inspect
import struct
import time
import traceback
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any

import websockets
from websockets.asyncio.client import ClientConnection

from pymt5._logging import get_logger
from pymt5._metrics import MetricsCollector
from pymt5._rate_limiter import TokenBucketRateLimiter
from pymt5.constants import CMD_BOOTSTRAP, DEFAULT_COMMAND_TIMEOUT, DEFAULT_TOKEN_LENGTH, VALID_COMMANDS
from pymt5.crypto import AESCipher, initial_cipher
from pymt5.exceptions import MT5ConnectionError, MT5TimeoutError, ProtocolError, SessionError
from pymt5.protocol import ResponseFrame, build_command, pack_outer, parse_response_frame, unpack_outer

# Module-level logger shared by everything in this transport module.
logger = get_logger("pymt5.transport")

# Inspect the signature of ``websockets.connect`` once, at import time, and
# remember whether it accepts a ``proxy`` keyword. ``connect()`` consults this
# flag to decide whether it may pass ``proxy=None``; caching the result here
# avoids re-running ``inspect.signature`` on every connection attempt.
_WS_CONNECT_HAS_PROXY = "proxy" in _inspect.signature(websockets.connect).parameters


class TransportState(enum.Enum):
    """Lifecycle states a WebSocket transport moves through.

    Each member's value is its lower-case name, so ``TransportState("ready")``
    resolves to ``TransportState.READY``.
    """

    # No connection: the initial state, and the state after a completed close.
    DISCONNECTED = "disconnected"
    # A connection attempt is in progress.
    CONNECTING = "connecting"
    # Connected and able to send commands.
    READY = "ready"
    # A shutdown is in progress.
    CLOSING = "closing"
    # The connection was lost.
    ERROR = "error"
@dataclass(slots=True)
class CommandResult:
    """Decoded reply (or server push) for one command.

    Attributes:
        command: Command identifier taken from the response frame.
        code: Result code taken from the response frame; ``connect()`` treats
            a non-zero code on the bootstrap reply as a failure.
        body: Payload bytes of the response frame.
    """

    command: int
    code: int
    body: bytes


class MT5WebSocketTransport:
    """Encrypted, command-oriented transport over a single WebSocket.

    Outgoing commands are built with ``build_command``, encrypted with the
    current cipher and wrapped with ``pack_outer``. A background task
    (``_recv_loop``) reverses those steps for every binary message and hands
    the parsed frame to ``_dispatch``.

    Replies are matched to callers per command id in FIFO order: each
    ``_send_raw`` call queues a future under its command id and the next frame
    carrying that id resolves the oldest still-waiting future. Every frame is
    additionally delivered to the listeners registered with ``on()``.
    """

    def __init__(
        self,
        uri: str,
        timeout: float = DEFAULT_COMMAND_TIMEOUT,
        rate_limit: float = 0,
        rate_burst: int = 20,
        metrics: MetricsCollector | None = None,
    ):
        """Create a disconnected transport; no I/O happens until ``connect()``.

        Args:
            uri: WebSocket URI to connect to.
            timeout: Seconds allowed for opening the socket and for each
                command's reply.
            rate_limit: ``rate`` argument forwarded to ``TokenBucketRateLimiter``.
            rate_burst: ``burst`` argument forwarded to ``TokenBucketRateLimiter``.
            metrics: Optional collector notified about connects, disconnects
                and sent/received commands.
        """
        self.uri = uri
        self.timeout = timeout
        self.ws: ClientConnection | None = None
        self._state = TransportState.DISCONNECTED
        # Until the bootstrap reply arrives the token is all zero bytes and the
        # cipher is the library's initial one.
        self.token = bytes(DEFAULT_TOKEN_LENGTH)
        self.cipher: AESCipher = initial_cipher()
        self._recv_task: asyncio.Task[None] | None = None
        # Held while a command is queued and written to the socket.
        self._lock = asyncio.Lock()
        self._rate_limiter = TokenBucketRateLimiter(rate=rate_limit, burst=rate_burst)
        # command id -> futures awaiting a reply, oldest first.
        self._pending: dict[int, deque[asyncio.Future[CommandResult]]] = defaultdict(deque)
        # command id -> callbacks invoked for every frame with that id.
        self._listeners: dict[int, set[Callable[[CommandResult], Awaitable[None] | None]]] = defaultdict(set)
        self._on_disconnect: Callable[[], None] | None = None
        # Set while close() is tearing the transport down, so the receive loop
        # does not report that teardown as an unexpected disconnect.
        self._shutdown_event = asyncio.Event()
        self._disconnect_lock = asyncio.Lock()
        self._metrics = metrics
        self._last_message_at: float = 0.0
        self._connected_at: float = 0.0
        self._callback_error_handlers: list[Callable] = []
        self._server_build: int = 0

    @property
    def state(self) -> TransportState:
        """Current transport connection state."""
        return self._state

    @property
    def is_ready(self) -> bool:
        """Whether the transport is ready to send commands."""
        return self._state == TransportState.READY

    @is_ready.setter
    def is_ready(self, value: bool) -> None:
        """Backward-compatible setter: truthy -> READY, falsy -> DISCONNECTED."""
        self._state = TransportState.READY if value else TransportState.DISCONNECTED

    @property
    def server_build(self) -> int:
        """Server build number extracted from the bootstrap response prefix."""
        return self._server_build

    async def connect(self) -> None:
        """Open the WebSocket and perform the bootstrap key exchange.

        An already-open transport is closed first. On success the state is
        ``READY``. If anything fails along the way, the socket and the receive
        task are released before the error propagates, so a failed attempt
        does not leave the transport half-open or stuck in ``CONNECTING``.

        Raises:
            MT5ConnectionError: The bootstrap reply has a non-zero code or is
                shorter than 66 bytes.
            MT5TimeoutError: The bootstrap command was not answered in time.
            Exception: Whatever ``websockets.connect`` / ``asyncio.wait_for``
                raise while opening the socket is propagated unchanged.
        """
        # Guard against double-connect: tear the old session down first.
        if self.ws is not None:
            await self.close()

        self._state = TransportState.CONNECTING
        self._shutdown_event.clear()
        self.cipher = initial_cipher()
        logger.info("connecting to %s", self.uri)

        connect_kwargs: dict[str, Any] = {
            "ping_interval": None,
            "max_size": None,
            "open_timeout": self.timeout,
            "additional_headers": {
                "Origin": "https://web.metatrader.app",
            },
        }
        # websockets >=15 auto-detects system proxy (proxy=True default)
        # which breaks the MT5 binary protocol; bypass it explicitly.
        if _WS_CONNECT_HAS_PROXY:
            connect_kwargs["proxy"] = None

        try:
            self.ws = await asyncio.wait_for(
                websockets.connect(self.uri, **connect_kwargs),
                timeout=self.timeout,
            )
            self._recv_task = asyncio.create_task(self._recv_loop())
            logger.debug("websocket open, sending bootstrap")

            bootstrap = await self._send_raw(CMD_BOOTSTRAP, self.token, check_ready=False)
            if bootstrap.code != 0:
                raise MT5ConnectionError(f"bootstrap failed: code={bootstrap.code}")
            body = bootstrap.body
            if len(body) < 66:
                raise MT5ConnectionError(f"bootstrap response too short: {len(body)}")

            # Bootstrap body layout as consumed here: bytes 0-1 are read as the
            # server build (U16 LE), bytes 2-65 become the session token and
            # the remainder is handed to AESCipher.
            self.token = body[2:66]
            self.cipher = AESCipher(body[66:])
            try:
                self._server_build = struct.unpack_from("<H", body, 0)[0]
            except struct.error:
                self._server_build = 0
        except BaseException:
            # Includes cancellation: never leave an open socket or a running
            # receive task behind after a failed attempt.
            await self._abort_connect()
            raise

        self._state = TransportState.READY
        self._connected_at = time.monotonic()
        if self._metrics:
            self._metrics.on_connect()
        logger.info("transport ready (key exchanged)")

    async def _abort_connect(self) -> None:
        """Release everything a failed ``connect()`` acquired, best effort."""
        try:
            await self.close()
        except Exception:
            # The caller is about to re-raise the original connect error; a
            # secondary cleanup failure must not replace it.
            logger.debug("cleanup after failed connect raised", exc_info=True)

    async def close(self) -> None:
        """Stop the receive task, close the socket and fail pending commands.

        Every command still waiting for a reply fails with
        ``SessionError("transport closed")``. The state ends as
        ``DISCONNECTED``.
        """
        logger.info("closing transport")
        async with self._disconnect_lock:
            self._state = TransportState.CLOSING
            self._shutdown_event.set()
            if self._recv_task is not None:
                self._recv_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self._recv_task
                self._recv_task = None
            if self.ws is not None:
                await self.ws.close()
                self.ws = None
            self._fail_all(SessionError("transport closed"))
            self._state = TransportState.DISCONNECTED
        logger.debug("transport closed")

    def on(self, command: int, callback: Callable[[CommandResult], Awaitable[None] | None]) -> None:
        """Register ``callback`` for every frame received with ``command``.

        The callback may be a plain function or return an awaitable, which is
        then awaited by the dispatcher.
        """
        self._listeners[command].add(callback)

    def off(self, command: int, callback: Callable[[CommandResult], Awaitable[None] | None] | None = None) -> None:
        """Unregister ``callback`` for ``command``, or all of its callbacks if ``None``."""
        if callback is None:
            self._listeners[command].clear()
            return
        self._listeners[command].discard(callback)

    async def send_command(self, command: int, payload: bytes | None = None) -> CommandResult:
        """Send ``command`` on a ready transport and wait for its reply.

        Args:
            command: Command identifier; must be in ``VALID_COMMANDS``.
            payload: Command payload; ``None`` is sent as an empty payload.

        Raises:
            ProtocolError: ``command`` is not a supported command.
            SessionError: The transport is not ready, or was closed while the
                command was waiting.
            MT5ConnectionError: There is no open WebSocket.
            MT5TimeoutError: No reply arrived within ``timeout`` seconds.
        """
        return await self._send_raw(command, payload or b"", check_ready=True)

    async def _send_raw(self, command: int, payload: bytes, check_ready: bool) -> CommandResult:
        """Encrypt and send one command, then wait for the matching reply.

        Args:
            command: Command identifier; must be in ``VALID_COMMANDS``.
            payload: Command payload bytes.
            check_ready: When true, refuse to send unless the state is READY
                (``connect()`` passes false for the bootstrap command).

        Raises:
            ProtocolError: ``command`` is not a supported command.
            SessionError: ``check_ready`` is set and the transport is not ready.
            MT5ConnectionError: There is no open WebSocket.
            MT5TimeoutError: No reply arrived within ``timeout`` seconds.
        """
        if command not in VALID_COMMANDS:
            raise ProtocolError(f"unsupported command: {command}")
        if check_ready and not self.is_ready:
            raise SessionError(f"transport not ready for command {command}")
        if self.ws is None:
            raise MT5ConnectionError("websocket not connected")

        await self._rate_limiter.acquire()
        future: asyncio.Future[CommandResult] = asyncio.get_running_loop().create_future()
        try:
            async with self._lock:
                # close() may have run while we waited for the rate limiter or
                # the lock, so look at the socket again before using it.
                ws = self.ws
                if ws is None:
                    raise MT5ConnectionError("websocket not connected")
                self._pending[command].append(future)
                inner = build_command(command, payload)
                encrypted = self.cipher.encrypt(inner)
                logger.debug("send cmd=%d payload=%d bytes", command, len(payload))
                await ws.send(pack_outer(encrypted))
                if self._metrics:
                    self._metrics.on_command_sent(command)
            try:
                return await asyncio.wait_for(future, timeout=self.timeout)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is the builtin TimeoutError from Python
                # 3.11 on; naming the asyncio one also covers 3.10.
                raise MT5TimeoutError(f"command {command} timed out after {self.timeout}s") from None
        finally:
            # wait_for() cancels the future on timeout, and a failed send or a
            # cancelled caller leaves it queued as well; drop it here so
            # abandoned futures cannot pile up in _pending.
            self._discard_pending(command, future)

    def _discard_pending(self, command: int, future: asyncio.Future[CommandResult]) -> None:
        """Remove ``future`` from the pending queue of ``command`` if still queued."""
        if future.done() and not future.cancelled():
            # A result or exception was set by _dispatch()/_fail_all(), both of
            # which pop the future from its queue before completing it.
            return
        queue = self._pending.get(command)
        if queue:
            with contextlib.suppress(ValueError):
                queue.remove(future)

    async def _recv_loop(self) -> None:
        """Receive, decrypt and dispatch frames until the connection ends.

        Text messages are ignored. A frame that cannot be decoded is logged
        and skipped without ending the loop. When the connection ends, whether
        with an error or by a clean close from the peer, pending commands are
        failed and the disconnect callback is invoked, unless ``close()``
        initiated the shutdown.
        """
        ws = self.ws
        assert ws is not None
        try:
            async for message in ws:
                if isinstance(message, str):
                    continue
                try:
                    raw = message if isinstance(message, bytes) else bytes(message)
                    _, _, encrypted = unpack_outer(raw)
                    decrypted = self.cipher.decrypt(encrypted)
                    frame = parse_response_frame(decrypted)
                    self._last_message_at = time.monotonic()
                    logger.debug(
                        "recv cmd=%d code=%d body=%d bytes", frame.command, frame.code, len(frame.body)
                    )
                    await self._dispatch(frame)
                except (struct.error, ValueError, TypeError, IndexError, ProtocolError) as exc:
                    logger.error("recv_loop parse error: %s", exc)
                    continue
        except asyncio.CancelledError:
            raise
        except (OSError, websockets.exceptions.WebSocketException) as exc:
            logger.error("recv_loop disconnected: %s", exc)
            await self._handle_disconnect(exc)
            return

        # Iterating a websockets connection finishes without raising when the
        # peer closes the connection cleanly. That is still a lost connection
        # unless close() asked for it.
        if not self._shutdown_event.is_set():
            closed = MT5ConnectionError("websocket closed by server")
            logger.error("recv_loop disconnected: %s", closed)
            await self._handle_disconnect(closed)

    async def _handle_disconnect(self, exc: Exception) -> None:
        """Fail pending commands, mark the transport ERROR and notify the owner."""
        self._fail_all(exc)
        self._state = TransportState.ERROR
        if self._metrics:
            self._metrics.on_disconnect(str(exc))
        # Serialize against close() to prevent double-disconnect: the callback
        # is only picked up if no shutdown is in progress.
        notify: Callable[[], None] | None = None
        async with self._disconnect_lock:
            if self._on_disconnect and not self._shutdown_event.is_set():
                notify = self._on_disconnect
        if notify:
            notify()

    async def _dispatch(self, frame: ResponseFrame) -> None:
        """Deliver one parsed frame to its waiting caller and to listeners.

        The oldest not-yet-completed future queued under the frame's command
        receives the result; already-completed futures found on the way are
        discarded. Each listener for the command is then called (and awaited
        if it returns an awaitable). A listener exception is logged and passed
        to the registered callback error handlers; it never stops delivery to
        the remaining listeners.
        """
        result = CommandResult(command=frame.command, code=frame.code, body=frame.body)
        if self._metrics:
            self._metrics.on_command_received(frame.command, frame.code)

        queue = self._pending.get(frame.command)
        if queue:
            while queue:
                future = queue.popleft()
                if not future.done():
                    future.set_result(result)
                    break

        # Iterate over a snapshot so listeners may call on()/off() safely.
        for callback in tuple(self._listeners.get(frame.command, ())):
            try:
                maybe = callback(result)
                if _inspect.isawaitable(maybe):
                    await maybe
            except Exception as exc:
                logger.error(
                    "callback %s raised %s:\n%s",
                    getattr(callback, "__name__", repr(callback)),
                    exc,
                    traceback.format_exc(),
                )
                for error_handler in tuple(self._callback_error_handlers):
                    try:
                        error_handler(exc, callback)
                    except Exception:
                        logger.warning(
                            "callback error handler %s itself raised",
                            getattr(error_handler, "__name__", repr(error_handler)),
                            exc_info=True,
                        )

    def _fail_all(self, exc: Exception) -> None:
        """Empty every pending queue, failing unfinished futures with ``exc``."""
        for queue in self._pending.values():
            while queue:
                future = queue.popleft()
                if not future.done():
                    future.set_exception(exc)