utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. #!/usr/bin/env python
  2. # License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
  3. import os
  4. import subprocess
  5. import traceback
  6. from contextlib import suppress
  7. from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple
  8. from kitty.types import run_once
  9. from kitty.utils import SSHConnectionData
  10. @run_once
  11. def ssh_options() -> Dict[str, str]:
  12. try:
  13. p = subprocess.run(['ssh'], stderr=subprocess.PIPE, encoding='utf-8')
  14. raw = p.stderr or ''
  15. except FileNotFoundError:
  16. return {
  17. '4': '', '6': '', 'A': '', 'a': '', 'C': '', 'f': '', 'G': '', 'g': '', 'K': '', 'k': '',
  18. 'M': '', 'N': '', 'n': '', 'q': '', 's': '', 'T': '', 't': '', 'V': '', 'v': '', 'X': '',
  19. 'x': '', 'Y': '', 'y': '', 'B': 'bind_interface', 'b': 'bind_address', 'c': 'cipher_spec',
  20. 'D': '[bind_address:]port', 'E': 'log_file', 'e': 'escape_char', 'F': 'configfile', 'I': 'pkcs11',
  21. 'i': 'identity_file', 'J': '[user@]host[:port]', 'L': 'address', 'l': 'login_name', 'm': 'mac_spec',
  22. 'O': 'ctl_cmd', 'o': 'option', 'p': 'port', 'Q': 'query_option', 'R': 'address',
  23. 'S': 'ctl_path', 'W': 'host:port', 'w': 'local_tun[:remote_tun]'
  24. }
  25. ans: Dict[str, str] = {}
  26. pos = 0
  27. while True:
  28. pos = raw.find('[', pos)
  29. if pos < 0:
  30. break
  31. num = 1
  32. epos = pos
  33. while num > 0:
  34. epos += 1
  35. if raw[epos] not in '[]':
  36. continue
  37. num += 1 if raw[epos] == '[' else -1
  38. q = raw[pos+1:epos]
  39. pos = epos
  40. if len(q) < 2 or q[0] != '-':
  41. continue
  42. if ' ' in q:
  43. opt, desc = q.split(' ', 1)
  44. ans[opt[1:]] = desc
  45. else:
  46. ans.update(dict.fromkeys(q[1:], ''))
  47. return ans
  48. def is_kitten_cmdline(q: Sequence[str]) -> bool:
  49. if not q:
  50. return False
  51. exe_name = os.path.basename(q[0]).lower()
  52. if exe_name == 'kitten' and q[1:2] == ['ssh']:
  53. return True
  54. if len(q) < 4:
  55. return False
  56. if exe_name != 'kitty':
  57. return False
  58. if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']:
  59. return True
  60. return q[1:3] == ['+runpy', 'from kittens.runner import main; main()'] and len(q) >= 6 and q[5] == 'ssh'
  61. def patch_cmdline(key: str, val: str, argv: List[str]) -> None:
  62. for i, arg in enumerate(tuple(argv)):
  63. if arg.startswith(f'--kitten={key}='):
  64. argv[i] = f'--kitten={key}={val}'
  65. return
  66. elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith(f'{key}=') or arg.startswith(f'{key} ')):
  67. argv[i] = f'{key}={val}'
  68. return
  69. idx = argv.index('ssh')
  70. argv.insert(idx + 1, f'--kitten={key}={val}')
  71. def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None:
  72. patch_cmdline('cwd', cwd, argv)
  73. def create_shared_memory(data: Any, prefix: str) -> str:
  74. import atexit
  75. import json
  76. from kitty.shm import SharedMemory
  77. db = json.dumps(data).encode('utf-8')
  78. with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm:
  79. shm.write_data_with_size(db)
  80. shm.flush()
  81. atexit.register(shm.unlink)
  82. return shm.name
  83. def read_data_from_shared_memory(shm_name: str) -> Any:
  84. import json
  85. import stat
  86. from kitty.shm import SharedMemory
  87. with SharedMemory(shm_name, readonly=True) as shm:
  88. shm.unlink()
  89. if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
  90. raise ValueError(f'Incorrect owner on pwfile: uid={shm.stats.st_uid} gid={shm.stats.st_gid}')
  91. mode = stat.S_IMODE(shm.stats.st_mode)
  92. if mode != stat.S_IREAD | stat.S_IWRITE:
  93. raise ValueError(f'Incorrect permissions on pwfile: 0o{mode:03o}')
  94. return json.loads(shm.read_data_with_size())
  95. def get_ssh_data(msgb: memoryview, request_id: str) -> Iterator[bytes]:
  96. from base64 import standard_b64decode
  97. yield b'\nKITTY_DATA_START\n' # to discard leading data
  98. try:
  99. msg = standard_b64decode(msgb).decode('utf-8')
  100. md = dict(x.split('=', 1) for x in msg.split(':'))
  101. pw = md['pw']
  102. pwfilename = md['pwfile']
  103. rq_id = md['id']
  104. except Exception:
  105. traceback.print_exc()
  106. yield b'invalid ssh data request message\n'
  107. else:
  108. try:
  109. env_data = read_data_from_shared_memory(pwfilename)
  110. if pw != env_data['pw']:
  111. raise ValueError('Incorrect password')
  112. if rq_id != request_id:
  113. raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
  114. except Exception as e:
  115. traceback.print_exc()
  116. yield f'{e}\n'.encode('utf-8')
  117. else:
  118. yield b'OK\n'
  119. encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
  120. # macOS has a 255 byte limit on its input queue as per man stty.
  121. # Not clear if that applies to canonical mode input as well, but
  122. # better to be safe.
  123. line_sz = 254
  124. while encoded_data:
  125. yield encoded_data[:line_sz]
  126. yield b'\n'
  127. encoded_data = encoded_data[line_sz:]
  128. yield b'KITTY_DATA_END\n'
  129. def set_env_in_cmdline(env: Dict[str, str], argv: List[str], clone: bool = True) -> None:
  130. from kitty.options.utils import DELETE_ENV_VAR
  131. if clone:
  132. patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
  133. return
  134. idx = argv.index('ssh')
  135. for i in range(idx, len(argv)):
  136. if argv[i] == '--kitten':
  137. idx = i + 1
  138. elif argv[i].startswith('--kitten='):
  139. idx = i
  140. env_dirs = []
  141. for k, v in env.items():
  142. if v is DELETE_ENV_VAR:
  143. x = f'--kitten=env={k}'
  144. else:
  145. x = f'--kitten=env={k}={v}'
  146. env_dirs.append(x)
  147. argv[idx+1:idx+1] = env_dirs
  148. def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
  149. other_ssh_args: Set[str] = set()
  150. boolean_ssh_args: Set[str] = set()
  151. for k, v in ssh_options().items():
  152. k = f'-{k}'
  153. if v:
  154. other_ssh_args.add(k)
  155. else:
  156. boolean_ssh_args.add(k)
  157. return boolean_ssh_args, other_ssh_args
  158. def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
  159. for x in extra_args:
  160. if arg == x or arg.startswith(f'{x}='):
  161. return x
  162. return ''
  163. passthrough_args = {f'-{x}' for x in 'NnfGT'}
  164. def set_server_args_in_cmdline(
  165. server_args: List[str], argv: List[str],
  166. extra_args: Tuple[str, ...] = ('--kitten',),
  167. allocate_tty: bool = False
  168. ) -> None:
  169. boolean_ssh_args, other_ssh_args = get_ssh_cli()
  170. ssh_args = []
  171. expecting_option_val = False
  172. found_extra_args: List[str] = []
  173. expecting_extra_val = ''
  174. ans = list(argv)
  175. found_ssh = False
  176. for i, argument in enumerate(argv):
  177. if not found_ssh:
  178. found_ssh = argument == 'ssh'
  179. continue
  180. if argument.startswith('-') and not expecting_option_val:
  181. if argument == '--':
  182. del ans[i+2:]
  183. if allocate_tty and ans[i-1] != '-t':
  184. ans.insert(i, '-t')
  185. break
  186. if extra_args:
  187. matching_ex = is_extra_arg(argument, extra_args)
  188. if matching_ex:
  189. if '=' in argument:
  190. exval = argument.partition('=')[-1]
  191. found_extra_args.extend((matching_ex, exval))
  192. else:
  193. expecting_extra_val = matching_ex
  194. expecting_option_val = True
  195. continue
  196. # could be a multi-character option
  197. all_args = argument[1:]
  198. for i, arg in enumerate(all_args):
  199. arg = f'-{arg}'
  200. if arg in boolean_ssh_args:
  201. ssh_args.append(arg)
  202. continue
  203. if arg in other_ssh_args:
  204. ssh_args.append(arg)
  205. rest = all_args[i+1:]
  206. if rest:
  207. ssh_args.append(rest)
  208. else:
  209. expecting_option_val = True
  210. break
  211. raise KeyError(f'unknown option -- {arg[1:]}')
  212. continue
  213. if expecting_option_val:
  214. if expecting_extra_val:
  215. found_extra_args.extend((expecting_extra_val, argument))
  216. expecting_extra_val = ''
  217. else:
  218. ssh_args.append(argument)
  219. expecting_option_val = False
  220. continue
  221. del ans[i+1:]
  222. if allocate_tty and ans[i] != '-t':
  223. ans.insert(i, '-t')
  224. break
  225. argv[:] = ans + server_args
  226. def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
  227. boolean_ssh_args, other_ssh_args = get_ssh_cli()
  228. port: Optional[int] = None
  229. expecting_port = expecting_identity = False
  230. expecting_option_val = False
  231. expecting_hostname = False
  232. expecting_extra_val = ''
  233. host_name = identity_file = found_ssh = ''
  234. found_extra_args: List[Tuple[str, str]] = []
  235. for i, arg in enumerate(args):
  236. if not found_ssh:
  237. if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
  238. found_ssh = arg
  239. continue
  240. if expecting_hostname:
  241. host_name = arg
  242. continue
  243. if arg.startswith('-') and not expecting_option_val:
  244. if arg in boolean_ssh_args:
  245. continue
  246. if arg == '--':
  247. expecting_hostname = True
  248. if arg.startswith('-p'):
  249. if arg[2:].isdigit():
  250. with suppress(Exception):
  251. port = int(arg[2:])
  252. continue
  253. elif arg == '-p':
  254. expecting_port = True
  255. elif arg.startswith('-i'):
  256. if arg == '-i':
  257. expecting_identity = True
  258. else:
  259. identity_file = arg[2:]
  260. continue
  261. if arg.startswith('--') and extra_args:
  262. matching_ex = is_extra_arg(arg, extra_args)
  263. if matching_ex:
  264. if '=' in arg:
  265. exval = arg.partition('=')[-1]
  266. found_extra_args.append((matching_ex, exval))
  267. continue
  268. expecting_extra_val = matching_ex
  269. expecting_option_val = True
  270. continue
  271. if expecting_option_val:
  272. if expecting_port:
  273. with suppress(Exception):
  274. port = int(arg)
  275. expecting_port = False
  276. elif expecting_identity:
  277. identity_file = arg
  278. elif expecting_extra_val:
  279. found_extra_args.append((expecting_extra_val, arg))
  280. expecting_extra_val = ''
  281. expecting_option_val = False
  282. continue
  283. if not host_name:
  284. host_name = arg
  285. if not host_name:
  286. return None
  287. if host_name.startswith('ssh://'):
  288. from urllib.parse import urlparse
  289. purl = urlparse(host_name)
  290. if purl.hostname:
  291. host_name = purl.hostname
  292. if purl.username:
  293. host_name = f'{purl.username}@{host_name}'
  294. if port is None and purl.port:
  295. port = purl.port
  296. if identity_file:
  297. if not os.path.isabs(identity_file):
  298. identity_file = os.path.expanduser(identity_file)
  299. if not os.path.isabs(identity_file):
  300. identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
  301. return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))