tests.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. """
  2. Cloudflared Integration tests
  3. """
  4. import unittest
  5. import subprocess
  6. import os
  7. import tempfile
  8. from contextlib import contextmanager
  9. from pexpect import pxssh
  10. class TestSSHBase(unittest.TestCase):
  11. """
  12. SSH test base class containing constants and helper funcs
  13. """
  14. HOSTNAME = os.environ["SSH_HOSTNAME"]
  15. SSH_USER = os.environ["SSH_USER"]
  16. SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
  17. AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
  18. SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
  19. SSH_OPTIONS = {"StrictHostKeyChecking": "no"}
  20. @classmethod
  21. def get_ssh_command(cls, pty=True):
  22. """
  23. Return ssh command arg list. If pty is true, a PTY is forced for the session.
  24. """
  25. cmd = [
  26. "ssh",
  27. "-o",
  28. "StrictHostKeyChecking=no",
  29. "-F",
  30. cls.AUTHORIZED_KEYS_SSH_CONFIG,
  31. cls.SSH_TARGET,
  32. ]
  33. if not pty:
  34. cmd += ["-T"]
  35. else:
  36. cmd += ["-tt"]
  37. return cmd
  38. @classmethod
  39. @contextmanager
  40. def ssh_session_manager(cls, *args, **kwargs):
  41. """
  42. Context manager for interacting with a pxssh session.
  43. Disables pty echo on the remote server and ensures session is terminated afterward.
  44. """
  45. session = pxssh.pxssh(options=cls.SSH_OPTIONS)
  46. session.login(
  47. cls.HOSTNAME,
  48. username=cls.SSH_USER,
  49. original_prompt=r"[#@$]",
  50. ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
  51. ssh_tunnels=kwargs.get("ssh_tunnels", {}),
  52. )
  53. try:
  54. session.sendline("stty -echo")
  55. session.prompt()
  56. yield session
  57. finally:
  58. session.logout()
  59. @staticmethod
  60. def get_command_output(session, cmd):
  61. """
  62. Executes command on remote ssh server and waits for prompt.
  63. Returns command output
  64. """
  65. session.sendline(cmd)
  66. session.prompt()
  67. return session.before.decode().strip()
  68. def exec_command(self, cmd, shell=False):
  69. """
  70. Executes command locally. Raises Assertion error for non-zero return code.
  71. Returns stdout and stderr
  72. """
  73. proc = subprocess.Popen(
  74. cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell
  75. )
  76. raw_out, raw_err = proc.communicate()
  77. out = raw_out.decode()
  78. err = raw_err.decode()
  79. self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
  80. return out.strip(), err.strip()
  81. class TestSSHCommandExec(TestSSHBase):
  82. """
  83. Tests inline ssh command exec
  84. """
  85. # Name of file to be downloaded over SCP on remote server.
  86. REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]
  87. @classmethod
  88. def get_scp_base_command(cls):
  89. return [
  90. "scp",
  91. "-o",
  92. "StrictHostKeyChecking=no",
  93. "-v",
  94. "-F",
  95. cls.AUTHORIZED_KEYS_SSH_CONFIG,
  96. ]
  97. @unittest.skip(
  98. "This creates files on the remote. Should be skipped until server is dockerized."
  99. )
  100. def test_verbose_scp_sink_mode(self):
  101. with tempfile.NamedTemporaryFile() as fl:
  102. self.exec_command(
  103. self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
  104. )
  105. def test_verbose_scp_source_mode(self):
  106. with tempfile.TemporaryDirectory() as tmpdirname:
  107. self.exec_command(
  108. self.get_scp_base_command()
  109. + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
  110. )
  111. local_filename = os.path.join(tmpdirname, self.REMOTE_SCP_FILENAME)
  112. self.assertTrue(os.path.exists(local_filename))
  113. self.assertTrue(os.path.getsize(local_filename) > 0)
  114. def test_pty_command(self):
  115. base_cmd = self.get_ssh_command()
  116. out, _ = self.exec_command(base_cmd + ["whoami"])
  117. self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
  118. out, _ = self.exec_command(base_cmd + ["tty"])
  119. self.assertNotEqual(out, "not a tty")
  120. def test_non_pty_command(self):
  121. base_cmd = self.get_ssh_command(pty=False)
  122. out, _ = self.exec_command(base_cmd + ["whoami"])
  123. self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
  124. out, _ = self.exec_command(base_cmd + ["tty"])
  125. self.assertEqual(out, "not a tty")
  126. class TestSSHShell(TestSSHBase):
  127. """
  128. Tests interactive SSH shell
  129. """
  130. # File path to a file on the remote server with root only read privileges.
  131. ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]
  132. def test_ssh_pty(self):
  133. with self.ssh_session_manager() as session:
  134. # Test shell launched as correct user
  135. username = self.get_command_output(session, "whoami")
  136. self.assertEqual(username.lower(), self.SSH_USER.lower())
  137. # Test USER env variable set
  138. user_var = self.get_command_output(session, "echo $USER")
  139. self.assertEqual(user_var.lower(), self.SSH_USER.lower())
  140. # Test HOME env variable set to true user home.
  141. home_env = self.get_command_output(session, "echo $HOME")
  142. pwd = self.get_command_output(session, "pwd")
  143. self.assertEqual(pwd, home_env)
  144. # Test shell launched in correct user home dir.
  145. self.assertIn(username, pwd)
  146. # Ensure shell launched with correct user's permissions and privs.
  147. # Cant read root owned 0700 files.
  148. output = self.get_command_output(
  149. session, f"cat {self.ROOT_ONLY_TEST_FILE_PATH}"
  150. )
  151. self.assertIn("Permission denied", output)
  152. def test_short_lived_cert_auth(self):
  153. with self.ssh_session_manager(
  154. ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG
  155. ) as session:
  156. username = self.get_command_output(session, "whoami")
  157. self.assertEqual(username.lower(), self.SSH_USER.lower())
  158. unittest.main()