# ssh.py (2.6 KB)
  1. import sys
  2. import struct
  3. import itertools
  4. assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
  5. def nbits(n):
  6. # Mimic mp_get_nbits for ordinary Python integers.
  7. assert 0 <= n
  8. smax = next(s for s in itertools.count() if (n >> (1 << s)) == 0)
  9. toret = 0
  10. for shift in reversed([1 << s for s in range(smax)]):
  11. if n >> shift != 0:
  12. n >>= shift
  13. toret += shift
  14. assert n <= 1
  15. if n == 1:
  16. toret += 1
  17. return toret
  18. def ssh_byte(n):
  19. return struct.pack("B", n)
  20. def ssh_uint32(n):
  21. return struct.pack(">L", n)
  22. def ssh_string(s):
  23. return ssh_uint32(len(s)) + s
  24. def ssh1_mpint(x):
  25. bits = nbits(x)
  26. bytevals = [0xFF & (x >> (8*n)) for n in range((bits-1)//8, -1, -1)]
  27. return struct.pack(">H" + "B" * len(bytevals), bits, *bytevals)
  28. def ssh2_mpint(x):
  29. bytevals = [0xFF & (x >> (8*n)) for n in range(nbits(x)//8, -1, -1)]
  30. return struct.pack(">L" + "B" * len(bytevals), len(bytevals), *bytevals)
  31. def decoder(fn):
  32. def decode(s, return_rest = False):
  33. item, length_consumed = fn(s)
  34. if return_rest:
  35. return item, s[length_consumed:]
  36. else:
  37. return item
  38. return decode
  39. @decoder
  40. def ssh_decode_byte(s):
  41. return struct.unpack_from("B", s, 0)[0], 1
  42. @decoder
  43. def ssh_decode_uint32(s):
  44. return struct.unpack_from(">L", s, 0)[0], 4
  45. @decoder
  46. def ssh_decode_string(s):
  47. length = ssh_decode_uint32(s)
  48. assert length + 4 <= len(s)
  49. return s[4:length+4], length+4
  50. @decoder
  51. def ssh1_get_mpint(s): # returns it unconsumed, still in wire encoding
  52. nbits = struct.unpack_from(">H", s, 0)[0]
  53. nbytes = (nbits + 7) // 8
  54. assert nbytes + 2 <= len(s)
  55. return s[:nbytes+2], nbytes+2
  56. @decoder
  57. def ssh1_decode_mpint(s):
  58. nbits = struct.unpack_from(">H", s, 0)[0]
  59. nbytes = (nbits + 7) // 8
  60. assert nbytes + 2 <= len(s)
  61. data = s[2:nbytes+2]
  62. v = 0
  63. for b in struct.unpack("B" * len(data), data):
  64. v = (v << 8) | b
  65. return v, nbytes+2
  66. AGENT_MAX_MSGLEN = 262144
  67. SSH1_AGENTC_REQUEST_RSA_IDENTITIES = 1
  68. SSH1_AGENT_RSA_IDENTITIES_ANSWER = 2
  69. SSH1_AGENTC_RSA_CHALLENGE = 3
  70. SSH1_AGENT_RSA_RESPONSE = 4
  71. SSH1_AGENTC_ADD_RSA_IDENTITY = 7
  72. SSH1_AGENTC_REMOVE_RSA_IDENTITY = 8
  73. SSH1_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
  74. SSH_AGENT_FAILURE = 5
  75. SSH_AGENT_SUCCESS = 6
  76. SSH2_AGENTC_REQUEST_IDENTITIES = 11
  77. SSH2_AGENT_IDENTITIES_ANSWER = 12
  78. SSH2_AGENTC_SIGN_REQUEST = 13
  79. SSH2_AGENT_SIGN_RESPONSE = 14
  80. SSH2_AGENTC_ADD_IDENTITY = 17
  81. SSH2_AGENTC_REMOVE_IDENTITY = 18
  82. SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
  83. SSH2_AGENTC_EXTENSION = 27
  84. SSH_AGENT_RSA_SHA2_256 = 2
  85. SSH_AGENT_RSA_SHA2_512 = 4