agenttest.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. #!/usr/bin/python3
  2. import sys
  3. import os
  4. import socket
  5. import base64
  6. import itertools
  7. import collections
  8. from ssh import *
  9. import agenttestdata
  10. assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
  11. test_session_id = b'Test16ByteSessId'
  12. assert len(test_session_id) == 16
  13. test_message_to_sign = b'test message to sign'
  14. TestSig2 = collections.namedtuple("TestSig2", "flags sig")
  15. class Key2(collections.namedtuple("Key2", "comment public sigs openssh")):
  16. def public_only(self):
  17. return Key2(self.comment, self.public, None, None)
  18. def Add(self):
  19. alg = ssh_decode_string(self.public)
  20. msg = (ssh_byte(SSH2_AGENTC_ADD_IDENTITY) +
  21. ssh_string(alg) +
  22. self.openssh +
  23. ssh_string(self.comment))
  24. return agent_query(msg)
  25. verb = "sign"
  26. def Use(self, flags):
  27. msg = (ssh_byte(SSH2_AGENTC_SIGN_REQUEST) +
  28. ssh_string(self.public) +
  29. ssh_string(test_message_to_sign))
  30. if flags is not None:
  31. msg += ssh_uint32(flags)
  32. rsp = agent_query(msg)
  33. t, rsp = ssh_decode_byte(rsp, True)
  34. assert t == SSH2_AGENT_SIGN_RESPONSE
  35. sig, rsp = ssh_decode_string(rsp, True)
  36. assert len(rsp) == 0
  37. return sig
  38. def Del(self):
  39. msg = (ssh_byte(SSH2_AGENTC_REMOVE_IDENTITY) +
  40. ssh_string(self.public))
  41. return agent_query(msg)
  42. @staticmethod
  43. def DelAll():
  44. msg = (ssh_byte(SSH2_AGENTC_REMOVE_ALL_IDENTITIES))
  45. return agent_query(msg)
  46. @staticmethod
  47. def List():
  48. msg = (ssh_byte(SSH2_AGENTC_REQUEST_IDENTITIES))
  49. rsp = agent_query(msg)
  50. t, rsp = ssh_decode_byte(rsp, True)
  51. assert t == SSH2_AGENT_IDENTITIES_ANSWER
  52. nk, rsp = ssh_decode_uint32(rsp, True)
  53. keylist = []
  54. for _ in range(nk):
  55. p, rsp = ssh_decode_string(rsp, True)
  56. c, rsp = ssh_decode_string(rsp, True)
  57. keylist.append(Key2(c, p, None, None))
  58. assert len(rsp) == 0
  59. return keylist
  60. @classmethod
  61. def make_examples(cls):
  62. cls.examples = agenttestdata.key2examples(cls, TestSig2)
  63. def iter_testsigs(self):
  64. for testsig in self.sigs:
  65. if testsig.flags == 0:
  66. yield testsig._replace(flags=None)
  67. yield testsig
  68. def iter_tests(self):
  69. for testsig in self.iter_testsigs():
  70. yield ([testsig.flags],
  71. " (flags={})".format(testsig.flags),
  72. testsig.sig)
  73. class Key1(collections.namedtuple(
  74. "Key1", "comment public challenge response private")):
  75. def public_only(self):
  76. return Key1(self.comment, self.public, None, None, None)
  77. def Add(self):
  78. msg = (ssh_byte(SSH1_AGENTC_ADD_RSA_IDENTITY) +
  79. self.private +
  80. ssh_string(self.comment))
  81. return agent_query(msg)
  82. verb = "decrypt"
  83. def Use(self, challenge):
  84. msg = (ssh_byte(SSH1_AGENTC_RSA_CHALLENGE) +
  85. self.public +
  86. ssh1_mpint(challenge) +
  87. test_session_id +
  88. ssh_uint32(1))
  89. rsp = agent_query(msg)
  90. t, rsp = ssh_decode_byte(rsp, True)
  91. assert t == SSH1_AGENT_RSA_RESPONSE
  92. assert len(rsp) == 16
  93. return rsp
  94. def Del(self):
  95. msg = (ssh_byte(SSH1_AGENTC_REMOVE_RSA_IDENTITY) +
  96. self.public)
  97. return agent_query(msg)
  98. @staticmethod
  99. def DelAll():
  100. msg = (ssh_byte(SSH1_AGENTC_REMOVE_ALL_RSA_IDENTITIES))
  101. return agent_query(msg)
  102. @staticmethod
  103. def List():
  104. msg = (ssh_byte(SSH1_AGENTC_REQUEST_RSA_IDENTITIES))
  105. rsp = agent_query(msg)
  106. t, rsp = ssh_decode_byte(rsp, True)
  107. assert t == SSH1_AGENT_RSA_IDENTITIES_ANSWER
  108. nk, rsp = ssh_decode_uint32(rsp, True)
  109. keylist = []
  110. for _ in range(nk):
  111. b, rsp = ssh_decode_uint32(rsp, True)
  112. e, rsp = ssh1_get_mpint(rsp, True)
  113. m, rsp = ssh1_get_mpint(rsp, True)
  114. c, rsp = ssh_decode_string(rsp, True)
  115. keylist.append(Key1(c, ssh_uint32(b)+e+m, None, None, None))
  116. assert len(rsp) == 0
  117. return keylist
  118. @classmethod
  119. def make_examples(cls):
  120. cls.examples = agenttestdata.key1examples(cls)
  121. def iter_tests(self):
  122. yield [self.challenge], "", self.response
  123. def agent_query(msg):
  124. msg = ssh_string(msg)
  125. s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  126. s.connect(os.environ["SSH_AUTH_SOCK"])
  127. s.send(msg)
  128. length = ssh_decode_uint32(s.recv(4))
  129. assert length < AGENT_MAX_MSGLEN
  130. return s.recv(length)
  131. def enumerate_bits(iterable):
  132. return ((1<<j, item) for j,item in enumerate(iterable))
  133. def gray_code(nbits):
  134. old = 0
  135. for i in itertools.chain(range(1, 1 << nbits), [0]):
  136. new = i ^ (i>>1)
  137. diff = new ^ old
  138. assert diff != 0 and (diff & (diff-1)) == 0
  139. yield old, new, diff
  140. old = new
  141. assert old == 0
  142. class TestRunner:
  143. def __init__(self):
  144. self.ok = True
  145. @staticmethod
  146. def fmt_response(response):
  147. return "'{}'".format(
  148. base64.encodebytes(response).decode("ASCII").replace("\n",""))
  149. @staticmethod
  150. def fmt_keylist(keys):
  151. return "{{{}}}".format(
  152. ",".join(key.comment.decode("ASCII") for key in sorted(keys)))
  153. def expect_success(self, text, response):
  154. if response == ssh_byte(SSH_AGENT_SUCCESS):
  155. print(text, "=> success")
  156. elif response == ssh_byte(SSH_AGENT_FAILURE):
  157. print("FAIL!", text, "=> failure")
  158. self.ok = False
  159. else:
  160. print("FAIL!", text, "=>", self.fmt_response(response))
  161. self.ok = False
  162. def check_keylist(self, K, expected_keys):
  163. keys = K.List()
  164. print("list keys =>", self.fmt_keylist(keys))
  165. if set(keys) != set(expected_keys):
  166. print("FAIL! Should have been", self.fmt_keylist(expected_keys))
  167. self.ok = False
  168. def gray_code_test(self, K):
  169. bks = list(enumerate_bits(K.examples))
  170. self.check_keylist(K, {})
  171. for old, new, diff in gray_code(len(K.examples)):
  172. bit, key = next((bit, key) for bit, key in bks if diff & bit)
  173. if new & bit:
  174. self.expect_success("insert " + key.comment.decode("ASCII"),
  175. key.Add())
  176. else:
  177. self.expect_success("delete " + key.comment.decode("ASCII"),
  178. key.Del())
  179. self.check_keylist(K, [key.public_only() for bit, key in bks
  180. if new & bit])
  181. def sign_test(self, K):
  182. for key in K.examples:
  183. for params, message, expected_answer in key.iter_tests():
  184. key.Add()
  185. actual_answer = key.Use(*params)
  186. key.Del()
  187. record = "{} with {}{}".format(
  188. K.verb, key.comment.decode("ASCII"), message)
  189. if actual_answer == expected_answer:
  190. print(record, "=> success")
  191. else:
  192. print("FAIL!", record, "=> {} but expected {}".format(
  193. self.fmt_response(actual_answer),
  194. self.fmt_response(expected_answer)))
  195. self.ok = False
  196. def run(self):
  197. self.expect_success("init: delete all ssh2 keys", Key2.DelAll())
  198. for K in [Key2, Key1]:
  199. self.gray_code_test(K)
  200. self.sign_test(K)
  201. # TODO: negative tests of all kinds.
  202. def main():
  203. Key2.make_examples()
  204. Key1.make_examples()
  205. tr = TestRunner()
  206. tr.run()
  207. if tr.ok:
  208. print("Test run passed")
  209. else:
  210. sys.exit("Test run failed!")
  211. if __name__ == "__main__":
  212. main()