main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. #!/usr/bin/env python
  2. # License: GPLv3 Copyright: 2020, Kovid Goyal <kovid at kovidgoyal.net>
  3. import json
  4. import os
  5. import shlex
  6. import shutil
  7. import subprocess
  8. import sys
  9. import tempfile
  10. import time
  11. from typing import Any, List, Optional
  12. from kitty.cli import parse_args
  13. from kitty.cli_stub import RemoteFileCLIOptions
  14. from kitty.constants import cache_dir
  15. from kitty.typing import BossType
  16. from kitty.utils import SSHConnectionData, command_for_open, get_editor, open_cmd
  17. from ..tui.handler import result_handler
  18. from ..tui.operations import faint, raw_mode, reset_terminal, styled
  19. from ..tui.utils import get_key_press
  20. is_ssh_kitten_sentinel = '!#*&$#($ssh-kitten)(##$'
  21. def key(x: str) -> str:
  22. return styled(x, bold=True, fg='green')
  23. def option_text() -> str:
  24. return '''\
  25. --mode -m
  26. choices=ask,edit
  27. default=ask
  28. Which mode to operate in.
  29. --path -p
  30. Path to the remote file.
  31. --hostname -h
  32. Hostname of the remote host.
  33. --ssh-connection-data
  34. The data used to connect over ssh.
  35. '''
  36. def show_error(msg: str) -> None:
  37. print(styled(msg, fg='red'), file=sys.stderr)
  38. print()
  39. print('Press any key to quit', flush=True)
  40. with raw_mode():
  41. while True:
  42. try:
  43. q = sys.stdin.buffer.read(1)
  44. if q:
  45. break
  46. except (KeyboardInterrupt, EOFError):
  47. break
  48. def ask_action(opts: RemoteFileCLIOptions) -> str:
  49. print('What would you like to do with the remote file on {}:'.format(styled(opts.hostname or 'unknown', bold=True, fg='magenta')))
  50. print(styled(opts.path or '', fg='yellow', fg_intense=True))
  51. print()
  52. def help_text(x: str) -> str:
  53. return faint(x)
  54. print('{}dit the file'.format(key('E')))
  55. print(help_text('The file will be downloaded and opened in an editor. Any changes you save will'
  56. ' be automatically sent back to the remote machine'))
  57. print()
  58. print('{}pen the file'.format(key('O')))
  59. print(help_text('The file will be downloaded and opened by the default open program'))
  60. print()
  61. print('{}ave the file'.format(key('S')))
  62. print(help_text('The file will be downloaded to a destination you select'))
  63. print()
  64. print('{}ancel'.format(key('C')))
  65. print()
  66. sys.stdout.flush()
  67. response = get_key_press('ceos', 'c')
  68. return {'e': 'edit', 'o': 'open', 's': 'save'}.get(response, 'cancel')
  69. def hostname_matches(from_hyperlink: str, actual: str) -> bool:
  70. if from_hyperlink == actual:
  71. return True
  72. if from_hyperlink.partition('.')[0] == actual.partition('.')[0]:
  73. return True
  74. return False
  75. class ControlMaster:
  76. def __init__(self, conn_data: SSHConnectionData, remote_path: str, cli_opts: RemoteFileCLIOptions, dest: str = ''):
  77. self.conn_data = conn_data
  78. self.cli_opts = cli_opts
  79. self.remote_path = remote_path
  80. self.dest = dest
  81. self.tdir = ''
  82. self.last_error_log = ''
  83. self.cmd_prefix = cmd = [
  84. conn_data.binary, '-o', f'ControlPath=~/.ssh/kitty-rf-{os.getpid()}-%C',
  85. '-o', 'TCPKeepAlive=yes', '-o', 'ControlPersist=yes'
  86. ]
  87. self.is_ssh_kitten = conn_data.binary is is_ssh_kitten_sentinel
  88. if self.is_ssh_kitten:
  89. del cmd[:]
  90. self.batch_cmd_prefix = cmd
  91. sk_cmdline = json.loads(conn_data.identity_file)
  92. while '-t' in sk_cmdline:
  93. sk_cmdline.remove('-t')
  94. cmd.extend(sk_cmdline[:-2])
  95. else:
  96. if conn_data.port:
  97. cmd.extend(['-p', str(conn_data.port)])
  98. if conn_data.identity_file:
  99. cmd.extend(['-i', conn_data.identity_file])
  100. self.batch_cmd_prefix = cmd + ['-o', 'BatchMode=yes']
  101. def check_call(self, cmd: List[str]) -> None:
  102. p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, stdin=subprocess.DEVNULL)
  103. stdout = p.communicate()[0]
  104. if p.wait() != 0:
  105. out = stdout.decode('utf-8', 'replace')
  106. raise Exception(f'The ssh command: {shlex.join(cmd)} failed with exit code {p.returncode} and output: {out}')
  107. def __enter__(self) -> 'ControlMaster':
  108. if not self.is_ssh_kitten:
  109. self.check_call(
  110. self.cmd_prefix + ['-o', 'ControlMaster=auto', '-fN', self.conn_data.hostname])
  111. self.check_call(
  112. self.batch_cmd_prefix + ['-O', 'check', self.conn_data.hostname])
  113. if not self.dest:
  114. self.tdir = tempfile.mkdtemp()
  115. self.dest = os.path.join(self.tdir, os.path.basename(self.remote_path))
  116. return self
  117. def __exit__(self, *a: Any) -> None:
  118. if not self.is_ssh_kitten:
  119. subprocess.Popen(
  120. self.batch_cmd_prefix + ['-O', 'exit', self.conn_data.hostname],
  121. stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL
  122. ).wait()
  123. if self.tdir:
  124. shutil.rmtree(self.tdir)
  125. @property
  126. def is_alive(self) -> bool:
  127. if self.is_ssh_kitten:
  128. return True
  129. return subprocess.Popen(
  130. self.batch_cmd_prefix + ['-O', 'check', self.conn_data.hostname],
  131. stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL
  132. ).wait() == 0
  133. def check_hostname_matches(self) -> bool:
  134. if self.is_ssh_kitten:
  135. return True
  136. cp = subprocess.run(self.batch_cmd_prefix + [self.conn_data.hostname, 'hostname', '-f'], stdout=subprocess.PIPE,
  137. stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL)
  138. if cp.returncode == 0:
  139. q = tuple(filter(None, cp.stdout.decode('utf-8').strip().splitlines()))[-1]
  140. if not hostname_matches(self.cli_opts.hostname or '', q):
  141. print(reset_terminal(), end='')
  142. print(f'The remote hostname {styled(q, fg="green")} does not match the')
  143. print(f'hostname in the hyperlink {styled(self.cli_opts.hostname or "", fg="red")}')
  144. print('This indicates that kitty has not connected to the correct remote machine.')
  145. print('This can happen, for example, when using nested SSH sessions.')
  146. print(f'The hostname kitty used to connect was: {styled(self.conn_data.hostname, fg="yellow")}', end='')
  147. if self.conn_data.port is not None:
  148. print(f' with port: {self.conn_data.port}')
  149. print()
  150. print()
  151. print('Do you want to continue anyway?')
  152. print(
  153. f'{styled("Y", fg="green")}es',
  154. f'{styled("N", fg="red")}o', sep='\t'
  155. )
  156. sys.stdout.flush()
  157. response = get_key_press('yn', 'n')
  158. print(reset_terminal(), end='')
  159. return response == 'y'
  160. return True
  161. def show_error(self, msg: str) -> None:
  162. if self.last_error_log:
  163. print(self.last_error_log, file=sys.stderr)
  164. self.last_error_log = ''
  165. show_error(msg)
  166. def download(self) -> bool:
  167. cmdline = self.batch_cmd_prefix + [self.conn_data.hostname, 'cat', shlex.quote(self.remote_path)]
  168. with open(self.dest, 'wb') as f:
  169. cp = subprocess.run(cmdline, stdout=f, stderr=subprocess.PIPE, stdin=subprocess.DEVNULL)
  170. if cp.returncode != 0:
  171. self.last_error_log = f'The command: {shlex.join(cmdline)} failed\n' + cp.stderr.decode()
  172. return False
  173. return True
  174. def upload(self, suppress_output: bool = True) -> bool:
  175. cmd_prefix = self.cmd_prefix if suppress_output else self.batch_cmd_prefix
  176. cmd = cmd_prefix + [self.conn_data.hostname, 'cat', '>', shlex.quote(self.remote_path)]
  177. if not suppress_output:
  178. print(shlex.join(cmd))
  179. with open(self.dest, 'rb') as f:
  180. if suppress_output:
  181. cp = subprocess.run(cmd, stdin=f, capture_output=True)
  182. if cp.returncode == 0:
  183. return True
  184. self.last_error_log = f'The command: {shlex.join(cmd)} failed\n' + cp.stdout.decode()
  185. else:
  186. return subprocess.run(cmd, stdin=f).returncode == 0
  187. return False
# What main() returns and handle_result() receives: the local path of a
# downloaded file to open, or None when there is nothing to open.
Result = Optional[str]
  189. def main(args: List[str]) -> Result:
  190. msg = 'Ask the user what to do with the remote file. For internal use by kitty, do not run it directly.'
  191. try:
  192. cli_opts, items = parse_args(args[1:], option_text, '', msg, 'kitty +kitten remote_file', result_class=RemoteFileCLIOptions)
  193. except SystemExit as e:
  194. if e.code != 0:
  195. print(e.args[0])
  196. input('Press Enter to quit')
  197. raise SystemExit(e.code)
  198. try:
  199. action = ask_action(cli_opts)
  200. finally:
  201. print(reset_terminal(), end='', flush=True)
  202. try:
  203. return handle_action(action, cli_opts)
  204. except Exception:
  205. print(reset_terminal(), end='', flush=True)
  206. import traceback
  207. traceback.print_exc()
  208. show_error('Failed with unhandled exception')
  209. return None
  210. def save_as(conn_data: SSHConnectionData, remote_path: str, cli_opts: RemoteFileCLIOptions) -> None:
  211. ddir = cache_dir()
  212. os.makedirs(ddir, exist_ok=True)
  213. last_used_store_path = os.path.join(ddir, 'remote-file-last-used.txt')
  214. try:
  215. with open(last_used_store_path) as f:
  216. last_used_path = f.read()
  217. except FileNotFoundError:
  218. last_used_path = tempfile.gettempdir()
  219. last_used_file = os.path.join(last_used_path, os.path.basename(remote_path))
  220. print(
  221. 'Where do you want to save the file? Leaving it blank will save it as:',
  222. styled(last_used_file, fg='yellow')
  223. )
  224. print('Relative paths will be resolved from:', styled(os.getcwd(), fg_intense=True, bold=True))
  225. print()
  226. from ..tui.path_completer import get_path
  227. try:
  228. dest = get_path()
  229. except (KeyboardInterrupt, EOFError):
  230. return
  231. if dest:
  232. dest = os.path.expandvars(os.path.expanduser(dest))
  233. if os.path.isdir(dest):
  234. dest = os.path.join(dest, os.path.basename(remote_path))
  235. with open(last_used_store_path, 'w') as f:
  236. f.write(os.path.dirname(os.path.abspath(dest)))
  237. else:
  238. dest = last_used_file
  239. if os.path.exists(dest):
  240. print(reset_terminal(), end='')
  241. print(f'The file {styled(dest, fg="yellow")} already exists. What would you like to do?')
  242. print(f'{key("O")}verwrite {key("A")}bort Auto {key("R")}ename {key("N")}ew name')
  243. response = get_key_press('anor', 'a')
  244. if response == 'a':
  245. return
  246. if response == 'n':
  247. print(reset_terminal(), end='')
  248. return save_as(conn_data, remote_path, cli_opts)
  249. if response == 'r':
  250. q = dest
  251. c = 0
  252. while os.path.exists(q):
  253. c += 1
  254. b, ext = os.path.splitext(dest)
  255. q = f'{b}-{c}{ext}'
  256. dest = q
  257. if os.path.dirname(dest):
  258. os.makedirs(os.path.dirname(dest), exist_ok=True)
  259. with ControlMaster(conn_data, remote_path, cli_opts, dest=dest) as master:
  260. if master.check_hostname_matches():
  261. if not master.download():
  262. master.show_error('Failed to copy file from remote machine')
  263. def handle_action(action: str, cli_opts: RemoteFileCLIOptions) -> Result:
  264. cli_data = json.loads(cli_opts.ssh_connection_data or '')
  265. if cli_data and cli_data[0] == is_ssh_kitten_sentinel:
  266. conn_data = SSHConnectionData(is_ssh_kitten_sentinel, cli_data[-1], -1, identity_file=json.dumps(cli_data[1:]))
  267. else:
  268. conn_data = SSHConnectionData(*cli_data)
  269. remote_path = cli_opts.path or ''
  270. if action == 'open':
  271. print('Opening', cli_opts.path, 'from', cli_opts.hostname)
  272. dest = os.path.join(tempfile.mkdtemp(), os.path.basename(remote_path))
  273. with ControlMaster(conn_data, remote_path, cli_opts, dest=dest) as master:
  274. if master.check_hostname_matches():
  275. if master.download():
  276. return dest
  277. master.show_error('Failed to copy file from remote machine')
  278. elif action == 'edit':
  279. print('Editing', cli_opts.path, 'from', cli_opts.hostname)
  280. editor = get_editor()
  281. with ControlMaster(conn_data, remote_path, cli_opts) as master:
  282. if not master.check_hostname_matches():
  283. return None
  284. if not master.download():
  285. master.show_error(f'Failed to download {remote_path}')
  286. return None
  287. mtime = os.path.getmtime(master.dest)
  288. print(reset_terminal(), end='', flush=True)
  289. editor_process = subprocess.Popen(editor + [master.dest])
  290. while editor_process.poll() is None:
  291. time.sleep(0.1)
  292. newmtime = os.path.getmtime(master.dest)
  293. if newmtime > mtime:
  294. mtime = newmtime
  295. if master.is_alive:
  296. master.upload()
  297. print(reset_terminal(), end='', flush=True)
  298. if master.is_alive:
  299. if not master.upload(suppress_output=False):
  300. master.show_error(f'Failed to upload {remote_path}')
  301. else:
  302. master.show_error(f'Failed to upload {remote_path}, SSH master process died')
  303. elif action == 'save':
  304. print('Saving', cli_opts.path, 'from', cli_opts.hostname)
  305. save_as(conn_data, remote_path, cli_opts)
  306. return None
  307. @result_handler()
  308. def handle_result(args: List[str], data: Result, target_window_id: int, boss: BossType) -> None:
  309. if data:
  310. from kitty.fast_data_types import get_options
  311. cmd = command_for_open(get_options().open_url_with)
  312. open_cmd(cmd, data)
# Script entry point. main()'s own help text says the kitten is meant to be
# invoked by kitty rather than run directly; the return value is ignored here.
if __name__ == '__main__':
    main(sys.argv)