"""Utility to reach a TCP server behind an SSH jump host.
This script establishes an SSH tunnel to a remote host and forwards a
local port to a TCP service that is only reachable from that host. Once
the tunnel is up, you can interact with the remote service using the
local forwarding port.

Example
-------
python ssh_tunnel_client.py \
    --ssh-host jump.example.com --ssh-user alice --ssh-key ~/.ssh/id_rsa \
    --remote-host 10.0.0.5 --remote-port 5012

The script will open localhost:5012 by default; any TCP client pointed to
that address/port pair will effectively communicate with 10.0.0.5:5012
through the SSH tunnel.
"""
from __future__ import annotations

import argparse
import getpass
import logging
import os
import select
import socketserver
import sys
from typing import Optional, Tuple

import paramiko
+
+
# Module-level logger shared by the functions and handlers in this script.
LOGGER = logging.getLogger(__name__)
+
+
class ForwardServer(socketserver.ThreadingTCPServer):
    """Threaded TCP server that accepts the local side of the tunnel.

    The code that creates the server attaches three extra attributes to the
    instance (``transport``, ``remote_host`` and ``remote_port``); the
    request handler reads them back through ``self.server``.
    """

    # Handler threads are daemons, so they never keep the process alive
    # once the main thread exits.
    daemon_threads = True
    # Set SO_REUSEADDR so the local port can be re-bound right after a
    # restart.
    allow_reuse_address = True
+
+
class TunnelHandler(socketserver.BaseRequestHandler):
    """Forwards a local TCP connection through the SSH transport."""

    # Maximum number of bytes read from either side per ``recv`` call.
    buffer_size = 1024

    def handle(self) -> None:
        """Open a ``direct-tcpip`` channel and relay data until one side closes.

        Reads ``transport``, ``remote_host`` and ``remote_port`` from
        ``self.server``. Failures to open the channel and socket errors
        during the relay are logged; they are not raised to the caller.
        """
        transport: paramiko.Transport = self.server.transport  # type: ignore[attr-defined]
        remote_host: str = self.server.remote_host  # type: ignore[attr-defined]
        remote_port: int = self.server.remote_port  # type: ignore[attr-defined]

        try:
            chan = transport.open_channel(
                "direct-tcpip",
                (remote_host, remote_port),
                # Originator of the forwarded connection: the local client,
                # not the address of our own listening socket.
                self.client_address,
            )
        except Exception as exc:  # pragma: no cover - network specific
            LOGGER.error("Failed to open SSH channel: %s", exc)
            return

        if chan is None:
            LOGGER.error("SSH channel creation returned None")
            return

        LOGGER.info(
            "Forwarding connection from %s to %s:%s",
            self.client_address,
            remote_host,
            remote_port,
        )

        try:
            self._relay(chan)
        except OSError as exc:
            LOGGER.warning(
                "Forwarding for %s aborted: %s", self.client_address, exc
            )
        finally:
            chan.close()
            self.request.close()

    def _relay(self, chan) -> None:
        """Copy bytes in both directions until either endpoint reports EOF."""
        while True:
            # Wait for whichever side has data. A strict "read client, then
            # read server" loop would block forever whenever the server
            # speaks first, sends several chunks, or sends nothing back.
            readable, _, _ = select.select([self.request, chan], [], [])
            if self.request in readable:
                outbound = self.request.recv(self.buffer_size)
                if not outbound:
                    return
                chan.sendall(outbound)
            if chan in readable:
                inbound = chan.recv(self.buffer_size)
                if not inbound:
                    return
                self.request.sendall(inbound)
+
+
def _port_type(minimum: int):
    """Build an argparse ``type=`` converter for ports in ``minimum``..65535."""

    def convert(text: str) -> int:
        try:
            port = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid port number: {text!r}"
            ) from None
        if not minimum <= port <= 65535:
            raise argparse.ArgumentTypeError(
                f"port must be between {minimum} and 65535, got {port}"
            )
        return port

    return convert


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the tunnel script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``ssh_host``, ``ssh_port``, ``ssh_user``, ``ssh_key``,
        ``ssh_password``, ``remote_host``, ``remote_port``, ``local_port``
        and ``verbose``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments,
            including port numbers outside the valid TCP range.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the line breaks of the module docstring (usage example) in
        # the --help output instead of re-wrapping it into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--ssh-host", required=True, help="SSH jump host")
    parser.add_argument(
        "--ssh-port", type=_port_type(1), default=22, help="SSH port"
    )
    parser.add_argument("--ssh-user", required=True, help="SSH username")
    parser.add_argument(
        "--ssh-key",
        help="Path to private key for authentication (optional if password is provided)",
    )
    parser.add_argument(
        "--ssh-password",
        help="Password for SSH authentication (optional if key is provided)",
    )
    parser.add_argument(
        "--remote-host",
        required=True,
        help="Destination host reachable from the SSH server",
    )
    parser.add_argument(
        "--remote-port",
        type=_port_type(1),
        default=5012,
        help="Destination TCP port on the remote host",
    )
    parser.add_argument(
        "--local-port",
        # Port 0 stays legal here: it asks the OS for any free local port.
        type=_port_type(0),
        default=5012,
        help="Local port for the forwarded connection",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable debug logging",
    )
    return parser.parse_args(argv)
+
+
def create_ssh_client(args: argparse.Namespace) -> paramiko.SSHClient:
    """Create a paramiko client and connect it to the SSH jump host.

    Authentication uses ``args.ssh_key`` and/or ``args.ssh_password``. When
    neither is given the user is prompted for a password; an empty answer
    means "no password", in which case paramiko is allowed to search for
    default private keys.

    Args:
        args: Parsed options; reads ``ssh_host``, ``ssh_port``, ``ssh_user``,
            ``ssh_key`` and ``ssh_password``.

    Returns:
        The connected client.

    Raises:
        Whatever ``SSHClient.connect`` raises (paramiko exceptions or
        ``OSError``); the client is closed before the error propagates.
    """
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(security): WarningPolicy accepts unknown host keys after logging a
    # warning, so a man-in-the-middle host is not rejected. Kept as-is to
    # preserve behaviour; consider paramiko.RejectPolicy for stricter setups.
    client.set_missing_host_key_policy(paramiko.WarningPolicy())

    # paramiko opens the key path verbatim, so expand "~" ourselves for the
    # cases where the shell did not (e.g. --ssh-key=~/.ssh/id_rsa).
    key_path = None if args.ssh_key is None else os.path.expanduser(args.ssh_key)

    password = args.ssh_password
    if password is None and key_path is None:
        # getpass always returns a str; map "" to None so the look_for_keys
        # condition below can actually become true.
        password = getpass.getpass("SSH password: ") or None

    try:
        client.connect(
            args.ssh_host,
            port=args.ssh_port,
            username=args.ssh_user,
            key_filename=key_path,
            password=password,
            look_for_keys=key_path is None and password is None,
        )
    except Exception:
        # Do not leave a half-initialised client behind on failure.
        client.close()
        raise
    return client
+
+
def start_forwarding(
    client: paramiko.SSHClient,
    remote_host: str,
    remote_port: int,
    local_port: int,
    local_host: str = "127.0.0.1",
) -> Tuple[ForwardServer, Tuple[str, int]]:
    """Create the local listening server for the tunnel.

    The server is bound but not yet serving; the caller is expected to run
    ``serve_forever()`` on it.

    Args:
        client: SSH client whose transport carries the forwarded channels.
        remote_host: Destination host as seen from the SSH server.
        remote_port: Destination TCP port on ``remote_host``.
        local_port: Local port to listen on.
        local_host: Local address to bind. Defaults to loopback; binding a
            non-loopback address exposes the tunnel to other machines.

    Returns:
        ``(server, server.server_address)``.

    Raises:
        RuntimeError: If the client has no active transport.
        OSError: If the local address cannot be bound.
    """
    transport = client.get_transport()
    if transport is None or not transport.is_active():
        raise RuntimeError("SSH transport is not available")

    server = ForwardServer((local_host, local_port), TunnelHandler)
    # TunnelHandler reads these back through ``self.server``.
    server.transport = transport  # type: ignore[attr-defined]
    server.remote_host = remote_host  # type: ignore[attr-defined]
    server.remote_port = remote_port  # type: ignore[attr-defined]
    return server, server.server_address
+
+
def main(argv: Optional[list[str]] = None) -> int:
    """Run the tunnel: parse options, connect, then forward until interrupted.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` after a normal shutdown (including Ctrl-C), ``1`` when the SSH
        connection or the local forwarding server could not be set up.
    """
    args = parse_args(argv)

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        client = create_ssh_client(args)
    except paramiko.AuthenticationException as exc:
        # Must precede SSHException, of which it is a subclass.
        LOGGER.error("Authentication failed: %s", exc)
        return 1
    except (paramiko.SSHException, OSError) as exc:
        # OSError covers refused connections, DNS failures, timeouts and an
        # unreadable key file, none of which are SSHException subclasses.
        LOGGER.error("Unable to establish SSH connection: %s", exc)
        return 1

    LOGGER.info(
        "Connected to %s. Forwarding localhost:%d to %s:%d",
        args.ssh_host,
        args.local_port,
        args.remote_host,
        args.remote_port,
    )

    try:
        server, local_address = start_forwarding(
            client,
            args.remote_host,
            args.remote_port,
            args.local_port,
        )
    except Exception as exc:
        # Top-level boundary: report any setup failure and exit cleanly.
        LOGGER.error("Failed to start port forwarding: %s", exc)
        client.close()
        return 1

    LOGGER.info("Tunnel established on %s:%d", *local_address)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        LOGGER.info("Interrupted by user, shutting down")
    finally:
        # Release the listening socket first, then the SSH connection.
        server.server_close()
        client.close()

    return 0
+
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
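# NOTE: web-page residue ("No comments: / Post a Comment") followed the
# script in the original paste; it is not part of the program.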