123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- /*
- * testbn.c: standalone test program for the bignum code.
- */
- /*
- * Accepts input on standard input, in the form generated by
- * testdata/bignum.py.
- */
- #include <stdio.h>
- #include <stdlib.h>
- #include <ctype.h>
- #include "ssh.h"
- #include "sshbn.h"
- void modalfatalbox(const char *p, ...)
- {
- va_list ap;
- fprintf(stderr, "FATAL ERROR: ");
- va_start(ap, p);
- vfprintf(stderr, p, ap);
- va_end(ap);
- fputc('\n', stderr);
- exit(1);
- }
- int random_byte(void)
- {
- modalfatalbox("random_byte called in testbn");
- return 0;
- }
- #define fromxdigit(c) ( (c)>'9' ? ((c)&0xDF) - 'A' + 10 : (c) - '0' )
- int main(int argc, char **argv)
- {
- char *buf;
- int line = 0;
- int passes = 0, fails = 0;
- printf("BIGNUM_INT_BITS = %d\n", (int)BIGNUM_INT_BITS);
- while ((buf = fgetline(stdin)) != NULL) {
- int maxlen = strlen(buf);
- unsigned char *data = snewn(maxlen, unsigned char);
- unsigned char *ptrs[5], *q;
- int ptrnum;
- char *bufp = buf;
- line++;
- q = data;
- ptrnum = 0;
- while (*bufp && !isspace((unsigned char)*bufp))
- bufp++;
- if (bufp)
- *bufp++ = '\0';
- while (*bufp) {
- char *start, *end;
- int i;
- while (*bufp && !isxdigit((unsigned char)*bufp))
- bufp++;
- start = bufp;
- if (!*bufp)
- break;
- while (*bufp && isxdigit((unsigned char)*bufp))
- bufp++;
- end = bufp;
- if (ptrnum >= lenof(ptrs))
- break;
- ptrs[ptrnum++] = q;
-
- for (i = -((end - start) & 1); i < end-start; i += 2) {
- unsigned char val = (i < 0 ? 0 : fromxdigit(start[i]));
- val = val * 16 + fromxdigit(start[i+1]);
- *q++ = val;
- }
- ptrs[ptrnum] = q;
- }
- if (!strcmp(buf, "mul")) {
- Bignum a, b, c, p;
- if (ptrnum != 3) {
- printf("%d: mul with %d parameters, expected 3\n", line, ptrnum);
- exit(1);
- }
- a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
- b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
- c = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
- p = bigmul(a, b);
- if (bignum_cmp(c, p) == 0) {
- passes++;
- } else {
- char *as = bignum_decimal(a);
- char *bs = bignum_decimal(b);
- char *cs = bignum_decimal(c);
- char *ps = bignum_decimal(p);
-
- printf("%d: fail: %s * %s gave %s expected %s\n",
- line, as, bs, ps, cs);
- fails++;
- sfree(as);
- sfree(bs);
- sfree(cs);
- sfree(ps);
- }
- freebn(a);
- freebn(b);
- freebn(c);
- freebn(p);
- } else if (!strcmp(buf, "modmul")) {
- Bignum a, b, m, c, p;
- if (ptrnum != 4) {
- printf("%d: modmul with %d parameters, expected 4\n",
- line, ptrnum);
- exit(1);
- }
- a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
- b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
- m = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
- c = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
- p = modmul(a, b, m);
- if (bignum_cmp(c, p) == 0) {
- passes++;
- } else {
- char *as = bignum_decimal(a);
- char *bs = bignum_decimal(b);
- char *ms = bignum_decimal(m);
- char *cs = bignum_decimal(c);
- char *ps = bignum_decimal(p);
-
- printf("%d: fail: %s * %s mod %s gave %s expected %s\n",
- line, as, bs, ms, ps, cs);
- fails++;
- sfree(as);
- sfree(bs);
- sfree(ms);
- sfree(cs);
- sfree(ps);
- }
- freebn(a);
- freebn(b);
- freebn(m);
- freebn(c);
- freebn(p);
- } else if (!strcmp(buf, "pow")) {
- Bignum base, expt, modulus, expected, answer;
- if (ptrnum != 4) {
- printf("%d: pow with %d parameters, expected 4\n", line, ptrnum);
- exit(1);
- }
- base = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
- expt = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
- modulus = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
- expected = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
- answer = modpow(base, expt, modulus);
- if (bignum_cmp(expected, answer) == 0) {
- passes++;
- } else {
- char *as = bignum_decimal(base);
- char *bs = bignum_decimal(expt);
- char *cs = bignum_decimal(modulus);
- char *ds = bignum_decimal(answer);
- char *ps = bignum_decimal(expected);
-
- printf("%d: fail: %s ^ %s mod %s gave %s expected %s\n",
- line, as, bs, cs, ds, ps);
- fails++;
- sfree(as);
- sfree(bs);
- sfree(cs);
- sfree(ds);
- sfree(ps);
- }
- freebn(base);
- freebn(expt);
- freebn(modulus);
- freebn(expected);
- freebn(answer);
- } else if (!strcmp(buf, "divmod")) {
- Bignum n, d, expect_q, expect_r, answer_q, answer_r;
- int fail;
- if (ptrnum != 4) {
- printf("%d: divmod with %d parameters, expected 4\n", line, ptrnum);
- exit(1);
- }
- n = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
- d = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
- expect_q = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
- expect_r = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
- answer_q = bigdiv(n, d);
- answer_r = bigmod(n, d);
- fail = FALSE;
- if (bignum_cmp(expect_q, answer_q) != 0) {
- char *as = bignum_decimal(n);
- char *bs = bignum_decimal(d);
- char *cs = bignum_decimal(answer_q);
- char *ds = bignum_decimal(expect_q);
- printf("%d: fail: %s / %s gave %s expected %s\n",
- line, as, bs, cs, ds);
- fail = TRUE;
- sfree(as);
- sfree(bs);
- sfree(cs);
- sfree(ds);
- }
- if (bignum_cmp(expect_r, answer_r) != 0) {
- char *as = bignum_decimal(n);
- char *bs = bignum_decimal(d);
- char *cs = bignum_decimal(answer_r);
- char *ds = bignum_decimal(expect_r);
- printf("%d: fail: %s mod %s gave %s expected %s\n",
- line, as, bs, cs, ds);
- fail = TRUE;
- sfree(as);
- sfree(bs);
- sfree(cs);
- sfree(ds);
- }
- freebn(n);
- freebn(d);
- freebn(expect_q);
- freebn(expect_r);
- freebn(answer_q);
- freebn(answer_r);
- if (fail)
- fails++;
- else
- passes++;
- } else {
- printf("%d: unrecognised test keyword: '%s'\n", line, buf);
- exit(1);
- }
- sfree(buf);
- sfree(data);
- }
- printf("passed %d failed %d total %d\n", passes, fails, passes+fails);
- return fails != 0;
- }
|