/*		Tests for functions mod powers of 2

This file contains tests for the special routines ModByPowerOf2(),
ShiftRight(), and ShiftLeft().  These work directly with the low-level
integer representation in order to provide improved performance for the
special cases x mod 2^b, x div 2^b, and x*2^b.


Some tests care about what numbers are too large to represent.  Big
integers in Magma are represented with a header word containing the
signed length, and then an array of "digits" containing the data.
The header word uses one bit for the sign, and the rest gives the length
of the data.  Thus the maximum representable value is (2^D)^(2^L - 1) - 1,
where D is the number of bits in a "digit" and L is the number of bits in
the length (so L is one fewer than the number of bits in the header word).

From the above, any integer whose absolute value is at least 2^(D*2^L) is
definitely not representable.  In terms of bits, we can only represent
values with fewer than D*2^L bits, i.e. fewer than 2^M bits, where
M = log2(D) + L.

Each quantity could be either 32-bit or 64-bit, so we have four
possibilities:

    D = 32, L = 31:  M = 36
    D = 64, L = 31:  M = 37
    D = 32, L = 63:  M = 68
    D = 64, L = 63:  M = 69

Thus we use M = 69 below for "too large to represent".  It would be
nice to use M = 37 in the development setup, but I cannot see an easy
way to determine (from the Magma language level) if this is appropriate.
*/
// Final component of the error message expected when a result cannot be
// represented; compared against the output of base_error().
EXPECTED_ERROR := "Integer too large to represent";
// 2^(2^M_UNREP) has more than 2^M_UNREP bits, so it is unrepresentable in
// every one of the four configurations listed in the header comment.
M_UNREP := 69;

// Return the text following the final ": " of an error message, provided
// that text contains no further colon or newline (one trailing newline is
// tolerated).  If the message has no such tail it is returned unchanged.
function base_error(errmsg)
    matched, _, parts := Regexp(": ([^:\n]*)\n?$", errmsg);
    if matched then
        return parts[1];
    end if;
    return errmsg;
end function;


// The exponent in an integer power must currently be small, so 2^(2^30)
// cannot be written directly.  Instead we square 2^(2^29); the resulting
// t30 = 2^(2^30) is the building block safe_power_two() uses for larger b.

t30 := (2^(2^29))^2;

// Return 2^b, including exponents b >= 2^30 that are too large for a direct
// 2^b.  Writing b = hi*2^30 + lo gives 2^b = 2^lo * t30^hi, where every
// exponent involved is small.  Assumes b >= 0 (all callers here pass one).
function safe_power_two(b)
    hi, lo := Quotrem(b, 2^30);
    if hi eq 0 then
        return 2^lo;
    end if;
    return 2^lo * t30^hi;
end function;

// Render an integer compactly for diagnostic messages.
//
// If |x| < 10^5 the full decimal string is returned.  Otherwise x is
// written as a*2^b + c, where |c| <= 2^8 is chosen to make b as large as
// possible, and printed in forms such as "a*2^b + c", "2^b - c" or "-2^b".
// If that leaves |a| > 2^10, the decimal string is used instead, abbreviated
// to its first and last ten characters when it is longer than 24.
function short_string(x)
    if Abs(x) lt 10^5 then return IntegerToString(x); end if;

    // Pick c from m = x mod 2^10.  If m lies within 2^8 of a multiple of
    // 2^10, then c = m or c = m - 2^10 makes x - c divisible by 2^10.
    // Otherwise the best achievable with |c| <= 2^8 is c = m - 2^9, which
    // gives 2-adic valuation exactly 9.
    m := x mod 2^10;
    if m le 2^8 then
	c := m;
    elif m ge 2^10 - 2^8 then
	c := m - 2^10;
    else
	c := m - 2^9;
    end if;
    // x - c is nonzero because |x| >= 10^5 > |c|.  Valuation yields b and
    // the cofactor a with x - c = a*2^b.
    b, a := Valuation(x - c, 2);

    // If a is too large, fall back to (possibly truncated) decimal digits.
    // This may be painful to debug later, admittedly.  The cutoff of 2^10
    // for a is *very* arbitrary.
    if Abs(a) gt 2^10 then
	digs := IntegerToString(x);
	if #digs le 24 then return digs; end if;
	return Sprintf("%o...%o", digs[[1..10]], digs[[#digs-9..#digs]]);
    end if;

    if a eq 1 then
	acomp := "";
    elif a eq -1 then
	acomp := "-";
    else
	acomp := Sprintf("%o*", a);
    end if;

    bcomp := Sprintf("2^%o", b);

    if c gt 0 then
	ccomp := Sprintf(" + %o", c);
    elif c lt 0 then
	ccomp := Sprintf(" - %o", -c);
    else
	ccomp := "";
    end if;

    return acomp cat bcomp cat ccomp;
end function;


//			ModByPowerOf2

// Check ModByPowerOf2(x, b) against x mod 2^b computed directly.
// t is a tuple <x, b>.  On mismatch, both values are printed and an error
// is raised.
procedure test_mod(t)
    x, b := Explode(t);
    u1 := ModByPowerOf2(x, b);
    u2 := x mod safe_power_two(b);
    if u1 ne u2 then
	printf "** ModByPowerOf2 mismatch!\n";
	printf "   x = %o\n", short_string(x);
	printf "   b = %o\n", short_string(b);
	printf "   ModByPowerOf2(x, b): %o\n", short_string(u1);
	// The reference value is x mod 2^b (not x mod b).
	printf "   x mod 2^b: %o\n", short_string(u2);
	error "ModByPowerOf2 failure";
    end if;
end procedure;

// Check ModByPowerOf2(x, b) when 2^b itself is too large to represent, so
// the reference value cannot be computed as x mod 2^b.  A representable
// x >= 0 is then necessarily smaller than 2^b, hence x mod 2^b = x.
// t is a tuple <x, b>.
procedure test_mod_unrep_b(t)
    x, b := Explode(t);
    // The shortcut x mod 2^b = x only holds for non-negative x; check this
    // precondition before exercising the routine under test.
    assert Sign(x) ge 0;
    u := ModByPowerOf2(x, b);
    if u ne x then
	printf "** ModByPowerOf2 mismatch!\n";
	printf "   x = %o\n", short_string(x);
	printf "   b = %o\n", short_string(b);
	printf "   ModByPowerOf2(x, b): %o\n", short_string(u);
	printf "   x mod 2^b: %o\n", short_string(x);
	error "ModByPowerOf2 failure";
    end if;
end procedure;

// Require ModByPowerOf2(x, b) to fail with EXPECTED_ERROR, comparing only
// the final component of the message (see base_error).  Any other error is
// reported and re-raised; unexpected success is reported and turned into
// an error.  t is a tuple <x, b>.
procedure test_mod_should_error(t)
    x, b := Explode(t);
    raised := false;

    try
        u := ModByPowerOf2(x, b);
    catch err
        raised := true;
        if base_error(err`Object) ne EXPECTED_ERROR then
            printf "** ModByPowerOf2 produced unexpected error!\n";
            printf "   x = %o\n", short_string(x);
            printf "   b = %o\n", short_string(b);
            error err;
        end if;
    end try;

    if raised then
        return;
    end if;

    printf "** ModByPowerOf2 succeeded but should have failed!\n";
    printf "   x = %o\n", short_string(x);
    printf "   b = %o\n", short_string(b);
    error "Test should not have succeeded";
end procedure;


// These tests should succeed; each <x, b> is compared by test_mod against
// x mod 2^b computed directly.
modtests := [
    // x = 0, b = 0
    <0, 0>,

    // x = 0, b = anything
    <0, 1>, <0, 29>, <0, 30>, <0, 32>, <0, 62>, <0, 64>, <0, 2^30>,

    // x = anything, b = 0
    <1, 0>, <-1, 0>, <2^30, 0>, <-2^30, 0>,

    // x tiny, 2^b small
    <3, 5>, <-11, 7>, <843, 7>, <-2031, 8>,

    // x small, 2^b small
    <2^29 - 11, 29>, <-(2^29 - 13), 29>, <2^30 - 25, 29>, <-(2^30 - 143), 29>,

    // x large but single, 2^b small
    <2^31 + 245, 17>, <-(2^31 + 245), 17>,

    // x large, 2^b small
    <2^64 - 1, 23>, <-(2^64 - 1), 23>,

    // x small and positive, 2^b large
    <31415, 32>, <26535, 54>,

    // x small and negative, 2^b large
    <-4669, 201>, <-(2^29 - 1), 45>,

    // x positive, x < 2^b
    <2^64, 72>, <609*2^35, 80>, <2^103 - 7, 103>,

    // x positive, x = 2^b
    <2^64, 64>, <2^53, 53>, <2^1024, 1024>,

    // x positive, x > 2^b
    <327*2^72, 72>, <327*2^72, 75>,

    // x negative, x mod 2^b = 0
    <-327*2^72, 34>, <-327*2^72, 64>, <-327*2^72, 72>,

    // x negative, 2^b > |x|
    <-327*2^72, 88>, <-327*2^72, 256>
];

for t in modtests do
    test_mod(t);
end for;

delete modtests;

print "ModByPowerOf2: Main tests passed";


// These tests involve an unrepresentable 2^b, but should succeed: for
// x >= 0 the expected result is x itself (checked by test_mod_unrep_b).
for t in [
    // x positive, 2^b unrepresentable
    <13, 2^(M_UNREP + 3)>, <2^35 + 271, 2^(M_UNREP + 8)>
] do
    test_mod_unrep_b(t);
end for;

print "ModByPowerOf2: (positive, unrep) tests passed";


// These tests should raise an error: for negative x with |x| < 2^b the
// result x mod 2^b = 2^b + x = 2^b - |x| needs at least 2^M_UNREP bits, so it
// is too large to represent.
modtests := [
    // negative x with b large enough that 2^b - |x| is too big
    <-3, 2^M_UNREP>, <-2^33, 2^(M_UNREP + 3)>
];

for t in modtests do
    test_mod_should_error(t);
end for;

delete modtests;

print "ModByPowerOf2: (negative, unrep) tests passed";


//			ShiftRight

// Blank line separating this section's output from the previous one.
printf "\n";

// Compare ShiftRight(x, b) with x div 2^b computed directly.  t is a tuple
// <x, b>; any disagreement is printed and raised as an error.
procedure test_div(t)
    x, b := Explode(t);
    shifted := ShiftRight(x, b);
    expected := x div safe_power_two(b);
    if shifted eq expected then
        return;
    end if;

    printf "** ShiftRight mismatch!\n";
    printf "   x = %o\n", short_string(x);
    printf "   b = %o\n", short_string(b);
    printf "   ShiftRight(x, b): %o\n", short_string(shifted);
    printf "   x div 2^b: %o\n", short_string(expected);
    error "ShiftRight failure";
end procedure;

// Check ShiftRight(x, b) when 2^b is too large to represent, so x div 2^b
// cannot be computed directly.  The expected quotient is 0 for x >= 0 and
// -1 for x < 0.  t is a tuple <x, b>.
procedure test_div_unrep_b(t)
    x, b := Explode(t);
    expected := (x lt 0) select -1 else 0;
    shifted := ShiftRight(x, b);
    if shifted eq expected then
        return;
    end if;

    printf "** ShiftRight mismatch!\n";
    printf "   x = %o\n", short_string(x);
    printf "   b = %o\n", short_string(b);
    printf "   ShiftRight(x, b): %o\n", short_string(shifted);
    printf "   x div 2^b: %o\n", short_string(expected);
    error "ShiftRight failure";
end procedure;

// These tests should succeed; each <x, b> is compared by test_div against
// x div 2^b computed directly.
divtests := [
    // x = 0, b = 0
    <0, 0>,

    // x = 0, b = anything
    <0, 1>, <0, 29>, <0, 30>, <0, 32>, <0, 62>, <0, 64>, <0, 2^30>,

    // x = anything, b = 0
    <1, 0>, <-1, 0>, <2^30, 0>, <-2^30, 0>, <2^30 - 1, 0>, <-(2^30 - 1), 0>,

    // x tiny, 2^b small
    <3, 5>, <-11, 1>, <-11, 7>, <843, 7>, <-2031, 8>,

    // x small, 2^b small
    <2^29 - 11, 29>, <-(2^29 - 13), 29>, <2^30 - 24, 29>, <-(2^30 - 143), 29>,

    // x small, 2^b large
    <13, 90>, <-13, 90>,

    // |x| < 2^b, fewer digits
    <2^31, 32>, <-2^31, 32>, <17*2^35, 90>, <-17*2^35, 90>,

    // |x| < 2^b, same number of digits
    <2^45, 46>, <-2^45, 46>, <2^64, 65>, <-2^64, 65>, <2^62, 63>, <-2^62, 63>,

    // x divisible by 2^b
    <3*2^42, 40>, <-3*2^42, 40>, <17*2^95, 33>, <-17*2^95, 33>,

    // quotient small
    <513*2^54, 57>, <-513*2^54, 57>, <9*2^312, 300>, <-9*2^312, 300>,

    // quotient large
    <5*2^234, 50>, <-5*2^234, 50>,

    // negative x with carries mattering
    <-(2^53 - 1), 40>, <-(2^73-1), 41>, <-(2^212 - 1), 94>
];

for i in [1..#divtests] do
    test_div(divtests[i]);
end for;

delete divtests;

print "ShiftRight: Main tests passed";


// These tests involve an unrepresentable 2^b, but should succeed: the
// quotient is 0 for x >= 0 and -1 for x < 0 (checked by test_div_unrep_b).
for t in [
    // x anything, 2^b unrepresentable
    <3, 2^M_UNREP>, <-5, 2^M_UNREP>,
    <13, 2^(M_UNREP + 3)>, <-13, 2^(M_UNREP + 3)>,
    <2^35 + 271, 2^(M_UNREP + 8)>, <-(2^35 + 271), 2^(M_UNREP + 8)>
] do
    test_div_unrep_b(t);
end for;

print "ShiftRight: Unrep tests passed";


//			ShiftLeft

// Blank line separating this section's output from the previous one.
printf "\n";

// Compare ShiftLeft(x, b) with x*2^b computed directly.  t is a tuple
// <x, b>; any disagreement is printed and raised as an error.
procedure test_mult(t)
    x, b := Explode(t);
    shifted := ShiftLeft(x, b);
    expected := x * safe_power_two(b);
    if shifted eq expected then
        return;
    end if;

    printf "** ShiftLeft mismatch!\n";
    printf "   x = %o\n", short_string(x);
    printf "   b = %o\n", short_string(b);
    printf "   ShiftLeft(x, b): %o\n", short_string(shifted);
    printf "   x*2^b: %o\n", short_string(expected);
    error "ShiftLeft failure";
end procedure;

// Require ShiftLeft(x, b) to fail with EXPECTED_ERROR, comparing only the
// final component of the message (see base_error).  Any other error is
// reported and re-raised; unexpected success is reported and turned into
// an error.  t is a tuple <x, b>.
procedure test_mult_should_error(t)
    x, b := Explode(t);
    raised := false;

    try
        u := ShiftLeft(x, b);
    catch err
        raised := true;
        if base_error(err`Object) ne EXPECTED_ERROR then
            printf "** ShiftLeft produced unexpected error!\n";
            printf "   x = %o\n", short_string(x);
            printf "   b = %o\n", short_string(b);
            error err;
        end if;
    end try;

    if raised then
        return;
    end if;

    printf "** ShiftLeft succeeded but should have failed!\n";
    printf "   x = %o\n", short_string(x);
    printf "   b = %o\n", short_string(b);
    error "Test should not have succeeded";
end procedure;


// These tests should succeed; each <x, b> is compared by test_mult against
// x*2^b computed directly.
multtests := [
    // x = 0, b = 0
    <0, 0>,

    // x = 0, b = anything
    <0, 1>, <0, 29>, <0, 30>, <0, 32>, <0, 62>, <0, 64>, <0, 2^30>,

    // x = anything, b = 0
    <1, 0>, <-1, 0>, <2^30, 0>, <-2^30, 0>, <2^30 - 1, 0>, <-(2^30 - 1), 0>,

    // x small, result small
    <13, 2>, <-57, 4>, <1, 29>, <-1, 29>, <57, 23>, <-57, 23>,

    // x small, result single
    <7, 28>, <-7, 28>,

    // x small, result large, word-aligned shifts
    <7, 32>, <7, 128>, <-13, 64>, <-1, 192>,

    // x small, result large, unaligned shifts
    <7, 35>, <7, 153>, <-13, 83>, <-1, 189>,

    // x large, word-aligned shifts
    <2985474468, 64>, <-3457970359, 96>,

    // x large, unaligned shifts
    <13810449767208660180, 13>, <-13810449767208660180, 13>,
    <13810449767208660180, 37>, <-13810449767208660180, 37>,
    <13810449767208660180, 84>, <-13810449767208660180, 84>
];

for i in [1..#multtests] do
    test_mult(multtests[i]);
end for;

delete multtests;

print "ShiftLeft: Basic tests passed";


// These tests should raise an error (checked by test_mult_should_error).
for t in [
    // any x with unrepresentable 2^b
    <3, 2^M_UNREP>, <-3, 2^M_UNREP>,
    <11, 2^(M_UNREP + 4)>, <-9, 2^(M_UNREP + 4)>,

    // Result too large to represent even though 2^b barely is.  If M_UNREP
    // exceeds the real limit (see note at top), 2^b is itself unrepresentable
    // and these cases no longer exercise the "barely representable" path.
    <313, 2^M_UNREP - 2>, <-313, 2^M_UNREP - 2>
] do
    test_mult_should_error(t);
end for;

print "ShiftLeft: Unrep tests passed";
