#!/usr/bin/env python3

# Convert OpenSSH known_hosts and known_hosts2 files to "new format" PuTTY
# host keys.
#   usage:
#     kh2reg.py [ --win ] known_hosts1 2 3 4 ... > hosts.reg
#       Creates a Windows .REG file (double-click to install).
#     kh2reg.py --unix    known_hosts1 2 3 4 ... > sshhostkeys
#       Creates data suitable for storing in ~/.putty/sshhostkeys (Unix).
# Line endings are someone else's problem as is traditional.
# Should run under either Python 2 or 3.
- import fileinput
- import base64
- import struct
- import string
- import re
- import sys
- import argparse
- import itertools
- import collections
- import hashlib
- from functools import reduce
def winmungestr(s):
    """Escape a string for use as a Registry key, like PuTTY's mungestr().

    Space, backslash, '*', '?', '%', '~' and control characters are
    replaced by '%XX' (upper-case hex of the character code).  A '.' is
    escaped only when it is the very first character of the string.
    """
    always_escaped = r' \*?%~'
    pieces = []
    for index, ch in enumerate(s):
        needs_escape = (
            ch in always_escaped
            or ord(ch) < ord(' ')
            or (ch == '.' and index == 0)
        )
        pieces.append(("%%%02X" % ord(ch)) if needs_escape else ch)
    return "".join(pieces)
def strtoint(s):
    """Convert arbitrary-length big-endian binary data to a Python int.

    The bytes are treated as an unsigned magnitude, most significant byte
    first.  Empty input yields 0 (a zero-length bignum) instead of raising.
    """
    # bytearray() iterates as integers on both Python 2 and 3; the explicit
    # initial value makes the fold well defined for empty input.
    return reduce(lambda acc, byte: (acc << 8) | byte, bytearray(s), 0)
def strtoint_le(s):
    """Convert arbitrary-length little-endian binary data to a Python int.

    The bytes are treated as an unsigned magnitude, least significant byte
    first.  Empty input yields 0 instead of raising.
    """
    # Walk the bytes from most to least significant; the explicit initial
    # value makes the fold well defined for empty input.
    return reduce(lambda acc, byte: (acc << 8) | byte,
                  reversed(bytearray(s)), 0)
def inttohex(n):
    """Render an int as lower-case hexadecimal with a leading "0x"."""
    return "0x" + format(n, "x")
def warn(s):
    """Write a warning to stderr, prefixed with the current input position.

    The file name and line number are taken from the fileinput module's
    global state.
    """
    location = (fileinput.filename(), fileinput.filelineno())
    sys.stderr.write("%s:%d: %s\n" % (location + (s,)))
class HMAC(object):
    """HMAC built on a hash constructor and that hash's block size."""

    def __init__(self, hashclass, blocksize):
        self.hashclass = hashclass
        self.blocksize = blocksize
        # One unsigned byte per position of a full block, used by xor_key().
        self.struct = struct.Struct(">{:d}B".format(self.blocksize))

    def pad_key(self, key):
        """Right-pad key with zero bytes up to the block size."""
        return key.ljust(self.blocksize, b'\0')

    def xor_key(self, key, xor):
        """Return the block-sized key with every byte XORed with xor."""
        masked = (octet ^ xor for octet in self.struct.unpack(key))
        return self.struct.pack(*masked)

    def keyed_hash(self, key, padbyte, string):
        """Hash string prefixed by the key XORed with padbyte."""
        prefix = self.xor_key(key, padbyte)
        return self.hashclass(prefix + string).digest()

    def compute(self, key, string):
        """Return the HMAC digest of string under key."""
        # Keys longer than one block are replaced by their hash.
        if len(key) > self.blocksize:
            key = self.hashclass(key).digest()
        key = self.pad_key(key)
        inner = self.keyed_hash(key, 0x36, string)
        return self.keyed_hash(key, 0x5C, inner)
def openssh_hashed_host_match(hashed_host, try_host):
    """Test whether try_host is the name behind a hashed known_hosts entry.

    hashed_host is the bytes of the hashed host field; only the
    "|1|salt|hash" form (base64 salt and HMAC-SHA1 value) is recognised.
    try_host is the candidate host name as bytes.

    Returns True on a match.  Returns False for an unrecognised prefix and
    also for a malformed entry (wrong number of '|'-separated parts or
    invalid base64), so one bad line cannot abort the whole conversion.
    """
    if not hashed_host.startswith(b'|1|'):
        return False # unrecognised magic number prefix
    try:
        salt_b64, expected_b64 = hashed_host[3:].split(b'|')
        salt = base64.decodebytes(salt_b64)
        expected = base64.decodebytes(expected_b64)
    except ValueError:
        # Raised by the tuple unpacking, and (as binascii.Error) by the
        # base64 decoder.
        return False
    mac = HMAC(hashlib.sha1, 64)
    return mac.compute(salt, try_host) == expected
def invert(n, p):
    """Compute an inverse of n modulo p with the extended Euclidean algorithm.

    Raises ZeroDivisionError when n is a multiple of p.  The result is not
    reduced into the range [0, p) and may be negative; callers reduce it
    modulo p themselves.
    """
    if n % p == 0:
        raise ZeroDivisionError()
    # Invariant: rem_cur == coef_cur * n (mod p), and likewise for *_next.
    rem_cur, coef_cur = n, 1
    rem_next, coef_next = p, 0
    while rem_next:
        quotient = rem_cur // rem_next
        rem_cur, coef_cur, rem_next, coef_next = (
            rem_next,
            coef_next,
            rem_cur - quotient * rem_next,
            coef_cur - quotient * coef_next,
        )
    assert abs(rem_cur) == 1
    # Multiplying by the final remainder (+1 or -1) fixes up the sign.
    return coef_cur * rem_cur
def jacobi(n, m):
    """Compute the Jacobi symbol of n over odd m.

    When m is prime this is the Legendre symbol: 0 if n is congruent to
    0 mod m, 1 if n is congruent to a non-zero square mod m, and -1 if n
    is not congruent to any square mod m.
    """
    assert m & 1
    sign = 1
    while True:
        n %= m
        if n == 0:
            return 0
        # Strip factors of two from n; each one flips the sign unless
        # m is congruent to 1 or 7 mod 8.
        while n % 2 == 0:
            n //= 2
            if (m & 7) not in {1, 7}:
                sign = -sign
        if n == 1:
            return sign
        # Swapping the arguments flips the sign when both are 3 mod 4.
        if (n & 3) == 3 and (m & 3) == 3:
            sign = -sign
        n, m = m, n
class SqrtModP(object):
    """Square-root finder for numbers modulo p.

    p must be an odd prime; its primality is not checked.
    """

    # Instances created through make(), keyed by the modulus passed in.
    instances = {}

    def __init__(self, p):
        modulus = abs(p)
        assert modulus & 1
        self.p = modulus

        # Write p - 1 as 2^e * k with k odd.
        odd_part, twos = modulus - 1, 0
        while odd_part % 2 == 0:
            odd_part //= 2
            twos += 1
        self.k = odd_part
        self.e = twos

        # The smallest positive non-square mod p, and its inverse.
        self.z = next(z for z in itertools.count(1)
                      if jacobi(z, self.p) == -1)
        self.zinv = invert(self.z, self.p)

    def sqrt_recurse(self, a):
        """Return some square root of a, which must be a non-zero square."""
        p = self.p
        ak = pow(a, self.k, p)
        # Square a^k repeatedly until it reaches 1; i ends up as e minus
        # the number of squarings that took.
        i = self.e
        while ak != 1 and i > 0:
            ak = ak * ak % p
            i -= 1
        assert i > 0
        if i == self.e:
            # a^k is already 1, so a^((k+1)/2) squares to a.
            return pow(a, (self.k + 1) // 2, p)
        # Otherwise multiply a by an even power of the non-square, take
        # the root of that, and divide the extra factor back out.
        shifted_root = self.sqrt_recurse(a * pow(self.z, 2**i, p))
        return shifted_root * pow(self.zinv, 2**(i-1), p) % p

    def sqrt(self, a):
        """Return the smaller square root of a mod p.

        Returns 0 when a is congruent to 0 mod p and raises ValueError
        when a has no square root mod p.
        """
        symbol = jacobi(a, self.p)
        if symbol == 0:
            return 0
        if symbol < 0:
            raise ValueError("{} has no square root mod {}".format(a, self.p))
        a %= self.p
        root = self.sqrt_recurse(a)
        assert root * root % self.p == a
        # Normalise to the smaller (or 'positive') one of the two roots.
        other_root = self.p - root
        return root if root <= other_root else other_root

    def __str__(self):
        return "{}({})".format(type(self).__name__, self.p)

    def __repr__(self):
        return self.__str__()

    @classmethod
    def make(cls, p):
        """Return the cached instance for modulus p, creating it if needed."""
        try:
            return cls.instances[p]
        except KeyError:
            instance = cls.instances[p] = cls(p)
            return instance

    @classmethod
    def root(cls, n, p):
        """Return the smaller square root of n mod p."""
        return cls.make(p).sqrt(n)
# Parameters of a short Weierstrass curve y^2 = x^3 + a*x + b over GF(p).
NistCurve = collections.namedtuple("NistCurve", ["p", "a", "b"])


def _nist_prime_curve(p, b):
    """Build a NistCurve; each curve in the table below has a = p - 3."""
    return NistCurve(p=p, a=p - 3, b=b)


# Curve parameters, keyed by SSH key type name.  The field primes are
# written in their defining power-of-two form.
nist_curves = {
    "ecdsa-sha2-nistp256": _nist_prime_curve(
        p=2**256 - 2**224 + 2**192 + 2**96 - 1,
        b=0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b,
    ),
    "ecdsa-sha2-nistp384": _nist_prime_curve(
        p=2**384 - 2**128 - 2**96 + 2**32 - 1,
        b=0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef,
    ),
    "ecdsa-sha2-nistp521": _nist_prime_curve(
        p=2**521 - 1,
        b=0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00,
    ),
}
class BlankInputLine(Exception):
    """Signals an input line with nothing to convert (blank or comment)."""
class UnknownKeyType(Exception):
    """Raised for an SSH key type this script has no translation for.

    The offending type name is available as the keytype attribute and is
    also passed to Exception, so str() and tracebacks show it.
    """

    def __init__(self, keytype):
        super(UnknownKeyType, self).__init__(keytype)
        self.keytype = keytype
class KeyFormatError(Exception):
    """Raised when a key cannot be parsed.

    The human-readable reason is available as the msg attribute and is
    also passed to Exception, so str() and tracebacks show it.
    """

    def __init__(self, msg):
        super(KeyFormatError, self).__init__(msg)
        self.msg = msg
def _split_ssh2_blob(blob):
    """Split an SSH-2 key blob into its (uint32 length, data) subfields."""
    sizefmt = ">L"
    header_len = struct.calcsize(sizefmt)
    subfields = []
    offset = 0
    while offset < len(blob):
        if len(blob) - offset < header_len:
            raise KeyFormatError("truncated length field in key blob")
        (size,) = struct.unpack_from(sizefmt, blob, offset)
        offset += header_len
        if len(blob) - offset < size:
            raise KeyFormatError("truncated data field in key blob")
        subfields.append(blob[offset:offset + size])
        offset += size
    return subfields


def _nist_point(sshkeytype, Q):
    """Decode an encoded ECDSA public point into integers (x, y)."""
    if len(Q) < 2:
        raise KeyFormatError("empty public point")
    # First a magic octet indicating point compression.
    point_type = struct.unpack_from("B", Q, 0)[0]
    Qrest = Q[1:]
    if point_type == 4:
        # Uncompressed: two equal-length bignums (X and Y).
        if len(Qrest) % 2 != 0:
            raise KeyFormatError("odd-length X+Y")
        bnlen = len(Qrest) // 2
        return strtoint(Qrest[:bnlen]), strtoint(Qrest[bnlen:])
    if point_type in (2, 3):
        # A compressed point just specifies X, and leaves Y implicit
        # except for parity, so recover it from the curve equation.
        curve = nist_curves[sshkeytype]
        x = strtoint(Qrest)
        yy = (x*x*x + curve.a*x + curve.b) % curve.p
        try:
            y = SqrtModP.root(yy, curve.p)
        except ValueError:
            raise KeyFormatError("compressed point is not on the curve")
        if y % 2 != point_type % 2:
            y = curve.p - y
        return x, y
    raise KeyFormatError("unsupported point format %d" % point_type)


def _eddsa_point(sshkeytype, encoded):
    """Decode an encoded Ed25519/Ed448 public key into integers (x, y)."""
    # Curve parameters, plus the length in bytes of an encoded point.
    p, d, a, nbytes = {
        "ssh-ed25519": (2**255 - 19, 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3, -1, 32),
        "ssh-ed448": (2**448-2**224-1, -39081, +1, 57),
    }[sshkeytype]
    if len(encoded) != nbytes:
        raise KeyFormatError("wrong length for %s public key" % sshkeytype)

    # Key material y, with the top bit of the last byte being repurposed
    # as the expected parity of the associated x (point compression).
    # That is bit 255 for Ed25519 but bit 455 for Ed448.
    parity_bit = 8 * nbytes - 1
    y = strtoint_le(encoded)
    x_parity = y >> parity_bit
    y &= ~(1 << parity_bit)

    try:
        # Recover x^2 = (y^2 - 1) / (d y^2 - a), then take the square root.
        xx = (y*y - 1) * invert(d*y*y - a, p) % p
        x = SqrtModP.root(xx, p)
    except (ZeroDivisionError, ValueError):
        raise KeyFormatError("public key is not a point on the curve")
    # Pick the square root of the correct parity.
    if (x % 2) != x_parity:
        x = p - x
    return x, y


def _parse_ssh2_key(sshkeytype, blob64):
    """Translate an SSH-2 key into (PuTTY key type, list of key params)."""
    try:
        blob = base64.decodebytes(blob64.encode("ASCII"))
    except ValueError:
        # Covers both non-ASCII text and invalid base64 (binascii.Error).
        raise KeyFormatError("key data is not valid base64")

    subfields = _split_ssh2_blob(blob)
    if not subfields:
        raise KeyFormatError("empty key blob")

    # The first field is keytype again.
    embedded = subfields[0].decode("ASCII", "replace")
    if embedded != sshkeytype:
        raise KeyFormatError(
            "outer and embedded key types do not match: '%s', '%s'"
            % (sshkeytype, embedded))

    if sshkeytype == "ssh-rsa":
        # The rest of the subfields we can treat as an opaque list
        # of bignums (same numbers and order as stored by PuTTY).
        return "rsa2", [strtoint(field) for field in subfields[1:]]
    if sshkeytype == "ssh-dss":
        # Same again.
        return "dss", [strtoint(field) for field in subfields[1:]]
    if sshkeytype in nist_curves:
        if len(subfields) != 3:
            raise KeyFormatError("wrong number of subfields in blob")
        curvename, Q = subfields[1:]
        # First is yet another copy of the key name.
        curve_id = curvename.decode("ASCII", "replace")
        if "ecdsa-sha2-" + curve_id != sshkeytype:
            raise KeyFormatError("key type mismatch ('%s' vs '%s')"
                                 % (sshkeytype, curve_id))
        x, y = _nist_point(sshkeytype, Q)
        return sshkeytype, [curvename, x, y]
    if sshkeytype in {"ssh-ed25519", "ssh-ed448"}:
        if len(subfields) != 2:
            raise KeyFormatError("wrong number of subfields in blob")
        x, y = _eddsa_point(sshkeytype, subfields[1])
        return sshkeytype, [x, y]
    raise UnknownKeyType(sshkeytype)


def _format_keyparam(param):
    """Render one key parameter: strings pass through, ints become hex."""
    if isinstance(param, str):
        return param
    if isinstance(param, bytes):
        return param.decode('ASCII')
    return inttohex(param)


def handle_line(line, output_formatter, try_hosts):
    """Convert one known_hosts line and emit it through output_formatter.

    line is one line of an OpenSSH known_hosts or known_hosts2 file.
    output_formatter is an object whose key(key, value) method receives
    one entry per usable host name on the line.  try_hosts is a sequence
    of candidate host names to test against hashed host entries.

    Blank lines and comments are ignored.  Unknown key types, malformed
    keys, wildcard patterns and unmatched hashed names are reported with
    warn() and skipped; they do not raise.
    """
    try:
        # Remove leading/trailing whitespace (should zap CR and LF)
        line = line.strip()

        # Skip blanks and comments
        if line == '' or line[0] == '#':
            raise BlankInputLine

        # Split line on spaces.
        fields = line.split(' ')
        if len(fields) < 2:
            raise KeyFormatError("too few fields on line")
        hostpat = fields[0]

        # Grotty heuristic to distinguish known_hosts from known_hosts2:
        # is second field entirely decimal digits?
        if re.match(r"\d*$", fields[1]):
            # Treat as SSH-1-type host key.
            # Format: hostpat bits10 exp10 mod10 comment...
            # (PuTTY doesn't store the number of bits.)
            if len(fields) < 4:
                raise KeyFormatError("too few fields for SSH-1 key")
            try:
                keyparams = [int(field) for field in fields[2:4]]
            except ValueError:
                raise KeyFormatError("non-numeric SSH-1 key parameter")
            keytype = "rsa"
        else:
            # Treat as SSH-2-type host key.
            # Format: hostpat keytype keyblob64 comment...
            if len(fields) < 3:
                raise KeyFormatError("missing key data")
            keytype, keyparams = _parse_ssh2_key(fields[1], fields[2])

        # Most of these are numbers, but there's the occasional
        # string that needs passing through.  The value is the same for
        # every host on the line, so build it once.
        value = ",".join(_format_keyparam(param) for param in keyparams)

        # Now print out one line per host pattern, discarding wildcards.
        for host in hostpat.split(','):
            if re.search(r"[*?!]", host):
                warn("skipping wildcard host pattern '%s'" % host)
                continue

            if host.startswith("|"):
                # Hashed host name: try each candidate in turn.
                for try_host in try_hosts:
                    if openssh_hashed_host_match(host.encode('ASCII'),
                                                 try_host.encode('UTF-8')):
                        host = try_host
                        break
                else:
                    warn("unable to match hashed hostname '%s'" % host)
                    continue

            m = re.match(r"\[([^]]*)\]:(\d*)$", host)
            if m:
                if not m.group(2):
                    warn("skipping host pattern with empty port '%s'" % host)
                    continue
                (host, port) = m.group(1, 2)
                port = int(port)
            else:
                port = 22

            # Slightly bizarre output key format: 'type@port:hostname'
            # XXX: does PuTTY do anything useful with literal IP[v4]s?
            key = keytype + ("@%d:%s" % (port, host))
            output_formatter.key(key, value)

    except UnknownKeyType as k:
        warn("unknown SSH key type '%s', skipping" % k.keytype)
    except KeyFormatError as k:
        warn("trouble parsing key (%s), skipping" % k.msg)
    except BlankInputLine:
        pass
class OutputFormatter(object):
    """Base class for output writers.

    Holds the output file handle as self.fh; header() and trailer() do
    nothing here.
    """

    def __init__(self, fh):
        self.fh = fh

    def header(self):
        """Write anything needed before the first entry (nothing here)."""
        return None

    def trailer(self):
        """Write anything needed after the last entry (nothing here)."""
        return None
class WindowsOutputFormatter(OutputFormatter):
    """Writes host keys as a Windows .REG file."""

    def header(self):
        # Output REG file header: the version line, a blank line, and
        # then the Registry key that the entries belong to.
        self.fh.write(r"""REGEDIT4

[HKEY_CURRENT_USER\Software\SimonTatham\PuTTY\SshHostKeys]
""")

    def key(self, key, value):
        # The key name is escaped with winmungestr(); the value is
        # written as-is.
        # XXX: worry about double quotes?
        self.fh.write("\"%s\"=\"%s\"\n" % (winmungestr(key), value))

    def trailer(self):
        # The spec at http://support.microsoft.com/kb/310516 says we need
        # a blank line at the end of the reg file:
        #
        #   Note the registry file should contain a blank line at the
        #   bottom of the file.
        #
        self.fh.write("\n")
class UnixOutputFormatter(OutputFormatter):
    """Writes host keys one per line, as 'key value'."""

    def key(self, key, value):
        entry = '%s %s\n' % (key, value)
        self.fh.write(entry)
def main():
    """Parse the command line, then convert every input line in turn."""
    parser = argparse.ArgumentParser(
        description="Convert OpenSSH known hosts files to PuTTY's format.")

    # The two output formats exclude each other and share one destination.
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument(
        "--windows", "--win",
        action='store_const',
        const=WindowsOutputFormatter,
        dest="output_formatter_class",
        help="Produce Windows .reg file output that regedit.exe can import"
             " (default).")
    format_group.add_argument(
        "--unix",
        action='store_const',
        const=UnixOutputFormatter,
        dest="output_formatter_class",
        help="Produce a file suitable for use as ~/.putty/sshhostkeys.")

    writable_file = argparse.FileType("w")
    parser.add_argument(
        "-o", "--output",
        type=writable_file,
        default=writable_file("-"),
        help="Output file to write to (default stdout).")
    parser.add_argument(
        "--hostname",
        action="append",
        help="Host name(s) to try matching against hashed "
             "host entries in input.")
    parser.add_argument(
        "infile",
        nargs="*",
        help="Input file(s) to read from (default stdin).")
    parser.set_defaults(output_formatter_class=WindowsOutputFormatter,
                        hostname=[])

    options = parser.parse_args()
    formatter = options.output_formatter_class(options.output)

    formatter.header()
    # fileinput's module-level state is what warn() reads for positions.
    for input_line in fileinput.input(options.infile):
        handle_line(input_line, formatter, options.hostname)
    formatter.trailer()


if __name__ == "__main__":
    main()
|