Source code for datalad_next.url_operations.ssh

"""Handler for operations, such as "download", on ssh:// URLs"""

# allow for |-type UnionType declarations
from __future__ import annotations

import logging
import subprocess
import sys
from itertools import chain
from pathlib import (
    Path,
    PurePosixPath,
)
from queue import (
    Full,
    Queue,
)
from typing import (
    Any,
    Dict,
    Generator,
    IO,
)
from urllib.parse import urlparse

from datalad_next.runners import (
    GeneratorMixIn,
    NoCaptureGeneratorProtocol,
    Protocol as RunnerProtocol,
    StdOutCaptureGeneratorProtocol,
    ThreadedRunner,
    CommandError,
)

from datalad_next.utils.consts import COPY_BUFSIZE

from . import (
    UrlOperations,
    UrlOperationsRemoteError,
    UrlOperationsResourceUnknown,
)

lgr = logging.getLogger('datalad.ext.next.ssh_url_operations')


__all__ = ['SshUrlOperations']


class SshUrlOperations(UrlOperations):
    """Handler for operations on ``ssh://`` URLs

    For downloading files, only servers that support execution of the
    commands 'printf', 'ls -nl', 'awk', and 'cat' are supported. This
    includes a wide range of operating systems, including devices that
    provide these commands via the 'busybox' software.

    .. note::
       The present implementation does not support SSH connection
       multiplexing; (re-)authentication is performed for each request.
       This limitation is likely to be removed in the future, and
       connection multiplexing supported where possible (non-Windows
       platforms).
    """
    # first try ls'ing the path, and catch a missing path with a dedicated
    # exit code, to be able to distinguish the original exit=2 of the ls call
    # from a later exit=2 from awk in case of a "fatal error".
    # when executed through ssh, only a missing file would yield 244, while
    # a connection error or other problem unrelated to the presence of a file
    # would yield a different error code (255 in case of a connection error)
    _stat_cmd = "printf \"\1\2\3\"; ls '{fpath}' &> /dev/null " \
                "&& ls -nl '{fpath}' | awk 'BEGIN {{ORS=\"\1\"}} {{print $5}}' " \
                "|| exit 244"
    _cat_cmd = "cat '{fpath}'"
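    # Illustration (not executed; host and path are made up): for
    # fpath='/data/file.bin', _stat_cmd expands on the server to
    #
    #   printf "\1\2\3"; ls '/data/file.bin' &> /dev/null \
    #     && ls -nl '/data/file.bin' | awk 'BEGIN {ORS="\1"} {print $5}' \
    #     || exit 244
    #
    # so a successful stat of a 1234-byte file writes the bytes
    # b'\x01\x02\x03' + b'1234' + b'\x01' to stdout, while a missing
    # file makes the remote command exit with 244.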
    def stat(self,
             url: str,
             *,
             credential: str | None = None,
             timeout: float | None = None) -> Dict:
        """Gather information on a URL target, without downloading it

        See :meth:`datalad_next.url_operations.UrlOperations.stat`
        for parameter documentation and exception behavior.
        """
        try:
            props = self._stat(
                url,
                cmd=SshUrlOperations._stat_cmd,
            )
        except CommandError as e:
            if e.code == 244:
                # this is the special code for a file-not-found
                raise UrlOperationsResourceUnknown(url) from e
            else:
                raise UrlOperationsRemoteError(url, message=str(e)) from e

        return {k: v for k, v in props.items() if not k.startswith('_')}
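    # A minimal usage sketch (hypothetical host and path, assuming
    # working SSH key authentication):
    #
    #   ops = SshUrlOperations()
    #   props = ops.stat('ssh://example.com/home/me/file.dat')
    #   props['content-length']  # size in bytes, as reported by 'ls -nl'
    #
    # A missing remote file raises UrlOperationsResourceUnknown.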
    def _stat(self, url: str, cmd: str) -> Dict:
        # any stream must start with this magic marker, or we do not
        # recognize what is happening
        # after this marker, the server will send the size of the
        # to-be-downloaded file in bytes, followed by another magic
        # b'\1', and the file content after that
        need_magic = b'\1\2\3'
        expected_size_str = b''
        expected_size = None

        ssh_cat = _SshCat(url)
        stream = ssh_cat.run(cmd, protocol=StdOutCaptureGeneratorProtocol)
        for chunk in stream:
            if need_magic:
                expected_magic = need_magic[:min(len(need_magic),
                                                 len(chunk))]
                incoming_magic = chunk[:len(need_magic)]
                # does the incoming data have the remaining magic bytes?
                if incoming_magic != expected_magic:
                    raise RuntimeError(
                        "Protocol error: report header not received")
                # reduce (still missing) magic, if any
                need_magic = need_magic[len(expected_magic):]
                # strip magic from input
                chunk = chunk[len(expected_magic):]
            if chunk and expected_size is None:
                # we have incoming data left and
                # we have not yet consumed the size info
                size_data = chunk.split(b'\1', maxsplit=1)
                expected_size_str += size_data[0]
                if len(size_data) > 1:
                    # this is not only size info, but we found the start of
                    # the data
                    expected_size = int(expected_size_str)
                    chunk = size_data[1]
            if expected_size:
                props = {
                    'content-length': expected_size,
                    '_stream': chain([chunk], stream) if chunk else stream,
                }
                return props
            # there should be no data left to process, or something went wrong
            assert not chunk
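    # Parsing example (illustrative only): for a 7-byte file containing
    # b'content', the combined stat+cat command used by download()
    # produces a stream like
    #
    #   b'\x01\x02\x03' + b'7' + b'\x01' + b'content'
    #
    # possibly split across arbitrary chunk boundaries. _stat() consumes
    # the magic marker and the size header, and hands back the rest of
    # the generator under the '_stream' key.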
    def download(self,
                 from_url: str,
                 to_path: Path | None,
                 *,
                 # unused, but theoretically could be used to
                 # obtain escalated/different privileges on a system
                 # to gain file access
                 credential: str | None = None,
                 hash: list[str] | None = None,
                 timeout: float | None = None) -> Dict:
        """Download a file by streaming it through an SSH connection.

        On the server-side, the file size is determined and sent. Afterwards
        the file content is sent via `cat` to the SSH client.

        See :meth:`datalad_next.url_operations.UrlOperations.download`
        for parameter documentation and exception behavior.
        """
        # this is pretty much shutil.copyfileobj() with the necessary
        # wrapping to perform hashing and progress reporting
        hasher = self._get_hasher(hash)
        progress_id = self._get_progress_id(from_url, to_path)

        dst_fp = None
        try:
            props = self._stat(
                from_url,
                cmd=f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}',
            )
            stream = props.pop('_stream')
            expected_size = props['content-length']
            dst_fp = sys.stdout.buffer if to_path is None \
                else open(to_path, 'wb')
            # Localize variable access to minimize overhead
            dst_fp_write = dst_fp.write
            # download can start
            self._progress_report_start(
                progress_id,
                ('Download %s to %s', from_url, to_path),
                'downloading',
                expected_size,
            )
            for chunk in stream:
                # write data
                dst_fp_write(chunk)
                # compute hash simultaneously
                hasher.update(chunk)
                self._progress_report_update(
                    progress_id, ('Downloaded chunk',), len(chunk))
            props.update(hasher.get_hexdigest())
            return props
        except CommandError as e:
            if e.code == 244:
                # this is the special code for a file-not-found
                raise UrlOperationsResourceUnknown(from_url) from e
            else:
                # wrap this into the datalad-standard, but keep the
                # original exception linked
                raise UrlOperationsRemoteError(from_url, message=str(e)) from e
        finally:
            if dst_fp and to_path is not None:
                dst_fp.close()
            self._progress_report_stop(progress_id, ('Finished download',))
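    # A minimal usage sketch (hypothetical URL and paths; 'md5' assumes
    # the platform's hashlib provides that algorithm):
    #
    #   ops = SshUrlOperations()
    #   props = ops.download(
    #       'ssh://example.com/home/me/file.dat',
    #       Path('file.dat'),
    #       hash=['md5'],
    #   )
    #   # props now contains 'content-length' and an 'md5' hexdigest;
    #   # passing to_path=None would stream to stdout instead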
    def upload(self,
               from_path: Path | None,
               to_url: str,
               *,
               credential: str | None = None,
               hash: list[str] | None = None,
               timeout: float | None = None) -> Dict:
        """Upload a file by streaming it through an SSH connection.

        It, more or less, runs `ssh <host> 'cat > <path>'`.

        See :meth:`datalad_next.url_operations.UrlOperations.upload`
        for parameter documentation and exception behavior.
        """
        if from_path is None:
            source_name = '<STDIN>'
            return self._perform_upload(
                src_fp=sys.stdin.buffer,
                source_name=source_name,
                to_url=to_url,
                hash_names=hash,
                expected_size=None,
                timeout=timeout,
            )
        else:
            # die right away, if we lack read permissions or there is no file
            with from_path.open("rb") as src_fp:
                return self._perform_upload(
                    src_fp=src_fp,
                    source_name=from_path,
                    to_url=to_url,
                    hash_names=hash,
                    expected_size=from_path.stat().st_size,
                    timeout=timeout,
                )
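    # A minimal usage sketch (hypothetical URL and path; missing remote
    # directories are created by the server-side 'mkdir -p'):
    #
    #   ops = SshUrlOperations()
    #   props = ops.upload(
    #       Path('file.dat'),
    #       'ssh://example.com/incoming/file.dat',
    #   )
    #
    # With from_path=None the payload is read from stdin, in which case
    # no expected size is known up front.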
    def _perform_upload(self,
                        src_fp: IO,
                        source_name: str,
                        to_url: str,
                        hash_names: list[str] | None,
                        expected_size: int | None,
                        timeout: int | None) -> dict:
        hasher = self._get_hasher(hash_names)

        # we limit the queue to a few items in order to make `queue.put()`
        # block relatively quickly, and thereby have the progress report
        # actually track the upload, and not just the feeding of the
        # queue
        upload_queue = Queue(maxsize=2)

        ssh_cat = _SshCat(to_url)
        ssh_runner_generator = ssh_cat.run(
            # leave a special exit code when writing fails, but not the
            # general SSH access
            "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244",
            protocol=NoCaptureGeneratorProtocol,
            stdin=upload_queue,
            timeout=timeout,
        )

        # file is open, we can start progress tracking
        progress_id = self._get_progress_id(source_name, to_url)
        self._progress_report_start(
            progress_id,
            ('Upload %s to %s', source_name, to_url),
            'uploading',
            expected_size,
        )
        try:
            upload_size = 0
            while ssh_runner_generator.runner.process.poll() is None:
                chunk = src_fp.read(COPY_BUFSIZE)
                if chunk == b'':
                    break
                chunk_size = len(chunk)
                # compute hash simultaneously
                hasher.update(chunk)
                # we are just putting stuff in the queue, and rely on
                # its maxsize to block the next call, in order to keep
                # the progress reports meaningful
                upload_queue.put(chunk, timeout=timeout)
                self._progress_report_update(
                    progress_id, ('Uploaded chunk',), chunk_size)
                upload_size += chunk_size
            # we're done, close queue
            upload_queue.put(None, timeout=timeout)

            # exhaust the generator, which might raise CommandError
            # or TimeoutError, if timeout was not `None`
            tuple(ssh_runner_generator)
        except CommandError as e:
            if e.code == 244:
                raise UrlOperationsResourceUnknown(to_url) from e
            else:
                raise UrlOperationsRemoteError(to_url, message=str(e)) from e
        except (TimeoutError, Full):
            ssh_runner_generator.runner.process.kill()
            raise TimeoutError
        finally:
            self._progress_report_stop(progress_id, ('Finished upload',))

        assert ssh_runner_generator.return_code == 0, \
            "Unexpected ssh " \
            f"return value: {ssh_runner_generator.return_code}"

        return {
            **hasher.get_hexdigest(),
            # return how much was copied. we could compare with
            # `expected_size` and error on mismatch, but not all
            # sources can provide that (e.g. stdin)
            'content-length': upload_size,
        }
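    # Backpressure illustration (standalone sketch, not part of this
    # module's control flow): a bounded queue makes the producer block
    # once the consumer falls behind, which is what keeps the progress
    # reports tied to actual upload throughput:
    #
    #   from queue import Queue, Full
    #   q = Queue(maxsize=2)
    #   q.put(b'chunk1'); q.put(b'chunk2')  # accepted immediately
    #   try:
    #       q.put(b'chunk3', timeout=0.1)   # blocks; raises Full if
    #   except Full:                        # nothing drains the queue
    #       pass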
class _SshCat:
    def __init__(self, url: str, *additional_ssh_args):
        self._parsed = urlparse(url)
        # make sure the essential pieces exist
        assert self._parsed.hostname
        assert self._parsed.path
        self.ssh_args: list[str] = list(additional_ssh_args)

    def run(self,
            payload_cmd: str,
            protocol: type[RunnerProtocol],
            stdin: Queue | None = None,
            timeout: float | None = None) -> Any | Generator:
        fpath = self._parsed.path
        cmd = ['ssh']
        cmd.extend(self.ssh_args)
        cmd.extend([
            '-e', 'none',
            self._parsed.hostname,
            payload_cmd.format(
                fdir=str(PurePosixPath(fpath).parent),
                fpath=fpath,
            ),
        ])
        return ThreadedRunner(
            cmd=cmd,
            protocol_class=protocol,
            stdin=subprocess.DEVNULL if stdin is None else stdin,
            timeout=timeout,
        ).run()
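# Illustration (hypothetical URL): for
# _SshCat('ssh://example.com/data/file.bin').run("cat '{fpath}'", ...)
# the assembled command line is
#
#   ssh -e none example.com "cat '/data/file.bin'"
#
# with '-e none' disabling the SSH escape character, so that binary
# payload bytes cannot be misinterpreted as an escape sequence.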