123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- """
- Cloudflared Integration tests
- """
- import unittest
- import subprocess
- import os
- import tempfile
- from contextlib import contextmanager
- from pexpect import pxssh
class TestSSHBase(unittest.TestCase):
    """
    SSH test base class containing constants and helper funcs.

    All connection settings are read from environment variables when the
    class body executes, so importing this module raises KeyError if any of
    them is unset.
    """

    # Remote host and login user shared by every test.
    HOSTNAME = os.environ["SSH_HOSTNAME"]
    SSH_USER = os.environ["SSH_USER"]
    SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
    # ssh_config file passed via ``-F`` to ssh and used by default for pxssh logins.
    AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
    # Alternate ssh_config file; presumably set up for short-lived certificate
    # auth (per its name) -- verify against the CI environment.
    SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
    # ``-o`` options applied both to pxssh sessions and to the ssh command line.
    SSH_OPTIONS = {"StrictHostKeyChecking": "no"}

    @classmethod
    def _ssh_option_args(cls):
        """Render SSH_OPTIONS as a flat ``["-o", "key=value", ...]`` arg list."""
        args = []
        for key, value in cls.SSH_OPTIONS.items():
            args += ["-o", f"{key}={value}"]
        return args

    @classmethod
    def get_ssh_command(cls, pty=True):
        """
        Return the ssh command arg list targeting SSH_TARGET.

        Callers append the remote command to run. If pty is true a PTY is
        forced for the session (``-tt``); otherwise PTY allocation is
        disabled (``-T``).
        """
        return [
            "ssh",
            # Derived from SSH_OPTIONS so the CLI and pxssh sessions stay in sync.
            *cls._ssh_option_args(),
            "-F",
            cls.AUTHORIZED_KEYS_SSH_CONFIG,
            cls.SSH_TARGET,
            "-tt" if pty else "-T",
        ]

    @classmethod
    @contextmanager
    def ssh_session_manager(cls, *args, **kwargs):
        """
        Context manager yielding a logged-in pxssh session to HOSTNAME as SSH_USER.

        Disables pty echo on the remote server and ensures the session is
        terminated afterward, even if logging out fails.

        Keyword args:
            ssh_config: ssh_config file used for the login
                (default: AUTHORIZED_KEYS_SSH_CONFIG).
            ssh_tunnels: forwarded to ``pxssh.login`` (default: no tunnels).

        Positional args are accepted for backward compatibility but ignored.
        """
        session = pxssh.pxssh(options=cls.SSH_OPTIONS)
        session.login(
            cls.HOSTNAME,
            username=cls.SSH_USER,
            original_prompt=r"[#@$]",
            ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
            ssh_tunnels=kwargs.get("ssh_tunnels", {}),
        )
        try:
            # Disable remote terminal echo so typed commands are not echoed
            # back into the captured output.
            session.sendline("stty -echo")
            session.prompt()
            yield session
        finally:
            try:
                session.logout()
            except Exception:
                # logout() waits for the remote shell to exit before closing the
                # session; if that fails (e.g. a timeout) close it here so the
                # local ssh process does not outlive the test, then re-raise.
                session.close()
                raise

    @staticmethod
    def get_command_output(session, cmd):
        """
        Execute cmd on the remote ssh server and wait for the shell prompt.

        Returns the stripped output printed before the prompt. Raises
        AssertionError if the prompt is not seen before the session timeout,
        rather than silently returning stale or partial output.
        """
        session.sendline(cmd)
        # pxssh.prompt() returns False (it does not raise) on timeout.
        if not session.prompt():
            raise AssertionError(f"timed out waiting for shell prompt after {cmd!r}")
        return session.before.decode().strip()

    def exec_command(self, cmd, shell=False):
        """
        Execute cmd locally and return its stripped (stdout, stderr).

        Fails the test (AssertionError) on a non-zero exit status; the failure
        message includes both output streams.
        """
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
        )
        out = result.stdout.decode()
        err = result.stderr.decode()
        self.assertEqual(result.returncode, 0, msg=f"stdout: {out} stderr: {err}")
        return out.strip(), err.strip()
class TestSSHCommandExec(TestSSHBase):
    """
    Tests inline ssh command exec and verbose scp transfers.
    """

    # Name of file to be downloaded over SCP on remote server.
    REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]

    @classmethod
    def get_scp_base_command(cls):
        """
        Return the verbose scp command prefix (no source/destination yet).

        The ``-o`` options are rendered from SSH_OPTIONS so scp agrees with
        the ssh helpers.
        """
        option_args = []
        for key, value in cls.SSH_OPTIONS.items():
            option_args += ["-o", f"{key}={value}"]
        return ["scp", *option_args, "-v", "-F", cls.AUTHORIZED_KEYS_SSH_CONFIG]

    @unittest.skip(
        "This creates files on the remote. Should be skipped until server is dockerized."
    )
    def test_verbose_scp_sink_mode(self):
        """Upload a local temp file to the remote user's home directory."""
        with tempfile.NamedTemporaryFile() as fl:
            self.exec_command(
                self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
            )

    def test_verbose_scp_source_mode(self):
        """Download REMOTE_SCP_FILENAME into a temp dir and check it is non-empty."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.exec_command(
                self.get_scp_base_command()
                + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
            )
            # scp stores the file in the target directory under its base name,
            # so a remote path containing directories must not be joined verbatim.
            local_filename = os.path.join(
                tmpdirname, os.path.basename(self.REMOTE_SCP_FILENAME)
            )
            self.assertTrue(
                os.path.exists(local_filename), msg=f"{local_filename} not downloaded"
            )
            self.assertGreater(os.path.getsize(local_filename), 0)

    def test_pty_command(self):
        """A forced-PTY session runs as SSH_USER and has a tty attached."""
        base_cmd = self.get_ssh_command()
        # exec_command already strips its output.
        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.lower(), self.SSH_USER.lower())
        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertNotEqual(out, "not a tty")

    def test_non_pty_command(self):
        """A non-PTY session runs as SSH_USER with no tty attached."""
        base_cmd = self.get_ssh_command(pty=False)
        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.lower(), self.SSH_USER.lower())
        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertEqual(out, "not a tty")
class TestSSHShell(TestSSHBase):
    """
    Interactive SSH shell tests driven through pxssh sessions.
    """

    # Remote path of a file that only root may read; used to confirm the shell
    # does not run with elevated privileges.
    ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]

    def test_ssh_pty(self):
        """The interactive shell runs as SSH_USER, in its home, without root privs."""
        expected_user = self.SSH_USER.lower()
        with self.ssh_session_manager() as sess:

            def run(command):
                return self.get_command_output(sess, command)

            # Shell is running as the expected user.
            whoami = run("whoami")
            self.assertEqual(whoami.lower(), expected_user)

            # $USER is populated.
            self.assertEqual(run("echo $USER").lower(), expected_user)

            # $HOME matches the directory the shell starts in...
            home_dir = run("echo $HOME")
            cwd = run("pwd")
            self.assertEqual(cwd, home_dir)
            # ...and that directory belongs to the user.
            self.assertIn(whoami, cwd)

            # Unprivileged shell: reading a root-only 0700 file must fail.
            cat_output = run(f"cat {self.ROOT_ONLY_TEST_FILE_PATH}")
            self.assertIn("Permission denied", cat_output)

    def test_short_lived_cert_auth(self):
        """Logging in with the short-lived cert ssh_config yields SSH_USER's shell."""
        cert_config = self.SHORT_LIVED_CERT_SSH_CONFIG
        with self.ssh_session_manager(ssh_config=cert_config) as sess:
            whoami = self.get_command_output(sess, "whoami")
            self.assertEqual(whoami.lower(), self.SSH_USER.lower())
# Only run the suite when executed as a script; importing the module (e.g. for
# pytest collection) must not trigger unittest.main(), which calls sys.exit.
if __name__ == "__main__":
    unittest.main()
|