#!/usr/bin/ruby

# Tests for the "Quadratic" class.

# Determine if a given number is probably a prime number.
func is_quadratic_pseudoprime (n, r=2) {

    # Trivial inputs: nothing at or below 1 is accepted, while 2 and 3 always are.
    if (n <= 1) { return false }
    if (n <= 3) { return true  }

    # Every value of `r`, from the starting one down to 1, has passed the test.
    if (r <= 0) { return true }

    # Third argument shared by both Quadratic values built below.
    var w = r+2

    # Raise Quadratic(r, 1, w) to the n-th power modulo n;
    # its first component must come back equal to `r`.
    var plus = Quadratic(r, 1, w).powmod(n, n)
    return false if (plus.a != r)

    # Same requirement for the value with the opposite sign on the second component.
    var minus = Quadratic(r, -1, w).powmod(n, n)
    return false if (minus.a != r)

    # The two second components must add up to exactly `n`.
    return false if ((plus.b + minus.b) != n)

    # Repeat the whole test with the next smaller `r`.
    __FUNC__(n, r-1)
}

# Both values must be accepted by the test when called with the default r=2.
assert(is_quadratic_pseudoprime(43))
assert(is_quadratic_pseudoprime(97))

# Components of the first 15 powers (exponents 0..14) of Quadratic(1, 1, 2),
# compared against the listed terms of the OEIS sequences named on each line.
with (Quadratic(1, 1, 2)) {|q|
    assert_eq(
        15.of { q.pow(_).a }        #=> A001333
        [1, 1, 3, 7, 17, 41, 99, 239, 577, 1393, 3363, 8119, 19601, 47321, 114243]
    )
    assert_eq(
        15.of { q.pow(_).b }        #=> A000129
        [0, 1, 2, 5, 12, 29, 70, 169, 408, 985, 2378, 5741, 13860, 33461, 80782]
    )
}

# Same check for the powers of Quadratic(1, 1, 3).
with (Quadratic(1, 1, 3)) {|q|
    assert_eq(
        15.of { q.pow(_).a }        #=> A026150
        [1, 1, 4, 10, 28, 76, 208, 568, 1552, 4240, 11584, 31648, 86464, 236224, 645376]
    )
    assert_eq(
        15.of { q.pow(_).b }        #=> A002605
        [0, 1, 2, 6, 16, 44, 120, 328, 896, 2448, 6688, 18272, 49920, 136384, 372608]
    )
}

# Exponent and modulus for the check below; the assertions expect 274177
# to show up as a common divisor with m = 2**64 + 1.
var n = (274177-1)
var m = (2**64 + 1)

# Raise Quadratic(3, 4, 2) to the n-th power modulo m, then take gcds of
# (first component - 1) and of the second component with m.
with (Quadratic(3, 4, 2)) {|q|
    var r = q.powmod(n, m)
    assert_eq(gcd(r.a-1, m), 274177)
    assert_eq(gcd(r.b, m), 274177)
}

# Comparison, accessor, arithmetic and inverse checks on two Quadratic values
# that share the same third argument (10).
do {
    var a = Quadratic(5, 8, 10)
    var b = Quadratic(3, 9, 10)

    # Ordering and (in)equality between the two values and with themselves.
    assert(a > b)
    assert(!(a < b))
    assert(b < a)
    assert(!(b > a))
    assert(a == a)
    assert(!(a != a))
    assert(b == b)
    assert(a != b)
    assert(b != a)
    assert(!(b == a))

    # The same equalities expressed through assert_eq / assert_ne.
    assert_eq(a, a)
    assert_eq(b, b)
    assert_ne(b, a)
    assert_ne(a, b)

    # Comparison against plain integers.
    assert(a > 4)
    assert(a >= 5)
    assert(a < 6)

    # Accessors return the constructor arguments in order: a, b, w.
    assert_eq(a.a, 5)
    assert_eq(a.b, 8)
    assert_eq(a.w, 10)

    # Arithmetic with an integer: + and - touch only the first component,
    # * and / scale both of the first two components; w is unchanged.
    assert_eq(a+5, Quadratic(a.a+5, a.b, a.w))
    assert_eq(a-5, Quadratic(a.a-5, a.b, a.w))
    assert_eq(a*5, Quadratic(a.a*5, a.b*5, a.w))
    assert_eq(a/5, Quadratic(a.a/5, a.b/5, a.w))

    # Arithmetic between two Quadratic values: component-wise sum, and the
    # product rule in which the product of the second components is scaled by w.
    assert_eq(a+b, Quadratic(a.a + b.a, a.b + b.b, a.w))
    assert_eq(a*b, Quadratic(a.a*b.a + a.b*b.b*a.w, a.b*b.a + a.a*b.b, a.w))

    # `inv` must agree with raising to the power -1.
    assert_eq(a.inv, a**(-1))
    assert_eq(b.inv, b**(-1))

    # Modular inverses must agree with powmod using negative exponents.
    assert_eq(a.invmod(43), a.powmod(-1, 43))
    assert_eq(b.invmod(97), b.powmod(-1, 97))
    assert_eq(b.invmod(146).sqr.mod(146), b.powmod(-2, 146))
    assert_eq(b.sqr.invmod(146), b.powmod(-2, 146))

    # Converting the result of an operation with to_n must match performing the
    # same operation on the to_n values, after both sides go through round(-30).
    assert_eq(a+b -> to_n.round(-30), a.to_n+b.to_n -> round(-30))
    assert_eq(a-b -> to_n.round(-30), a.to_n-b.to_n -> round(-30))
    assert_eq(a*b -> to_n.round(-30), a.to_n*b.to_n -> round(-30))
    assert_eq(a/b -> to_n.round(-30), a.to_n/b.to_n -> round(-30))
}

func Gaussian(a, b=0) {
    # Build a Quadratic value from `a` and `b` with the third argument fixed at -1.
    var w = -1
    return Quadratic(a, b, w)
}

# Range of small integers used for the exhaustive component loops below
# (and by a later loop in this file).
var r = (-2 .. 2)

# For every combination of components in the range, the `reals` of the
# Gaussian result must equal what cadd / csub / cmul / cdiv return.
for a in (r), b in (r), c in (r), d in (r) {

    assert_eq([Gaussian(a,b) + Gaussian(c,d) -> reals], [cadd(a,b,c,d)])
    assert_eq([Gaussian(a,b) - Gaussian(c,d) -> reals], [csub(a,b,c,d)])
    assert_eq([Gaussian(a,b) * Gaussian(c,d) -> reals], [cmul(a,b,c,d)])

    # Division is only checked when the divisor is non-zero.
    if (c*c + d*d != 0) {
        assert_eq([Gaussian(a,b) / Gaussian(c,d) -> reals], [cdiv(a,b,c,d)])
    }
}

# Powers with a randomly chosen exponent and modulus, compared against
# cpow and cpowmod.
for a in (r), b in (r) {

    var n = irand(0, 100)
    var m = irand(100, 1000)

    assert_eq([Gaussian(a,b)**n -> reals], [cpow(a,b,n)])
    assert_eq([Gaussian(a,b)**n -> reals].map { .mod(m) } , [cpowmod(a,b,n,m)])
}

func gaussian_sum(n) {

    # Running sum, kept as a two-element array.
    var acc = [0, 0]

    (1..n).each {|k|
        # k-th term: the (k-1)-th power of (0, 1), divided by k.
        var term = [cdiv(cpow(0, 1, k-1), k)]
        acc = [cadd(acc..., term...)]
    }

    # Scale the sum by n factorial.
    return [cmul(acc..., n!)]
}

# gaussian_sum(0) .. gaussian_sum(9); each entry is a two-element array.
var arr = 10.of(gaussian_sum)

# First and last elements of every entry, checked separately.
assert_eq(arr.map{.head}, [0, 1, 2, 4, 16, 104, 624, 3648, 29184, 302976])
assert_eq(arr.map{.tail}, [0, 0, 1, 3, 6, 30, 300, 2100, 11760, 105840])

# Inverse of Gaussian(a, b): plain, through Mod(...), and through the
# cdiv / complex_invmod / cmul / cmod functions.
do {
    var m = 10001
    var a = 43
    var b = 97

    # 1 / x must equal x.inv; with Mod, inverse times the original must give 1.
    assert_eq(Gaussian(1,0) / Gaussian(a, b), Gaussian(a,b).inv)
    assert_eq(Mod(Gaussian(a, b), m).inv * Gaussian(a,b), Mod(Gaussian(1,0), m))
    assert_eq(Mod(Gaussian(a, b), m).inv * Gaussian(a,b) -> lift, Gaussian(1,0))
    assert_eq(Mod(Gaussian(a, b), m)**(-1) * Gaussian(a,b) -> lift, Gaussian(1,0))

    # Closed form of the inverse: (a, -b) divided by a*a + b*b,
    # exactly and modulo m.
    assert_eq([cdiv(1, 0, a, b)], [a / (a*a + b*b), -b / (a*a + b*b)])
    assert_eq([a * invmod(a*a + b*b, m), -b * invmod(a*a + b*b, m)].map{.mod(m)}, [complex_invmod(a, b, m)])
    assert_eq([cmod(cmul(a, b, complex_invmod(a, b, m)), m)], [1, 0])
}

# Equality: true only when both components match.
assert(Gaussian(3,4) == Gaussian(3,4))
assert(!(Gaussian(3,4) == Gaussian(3,3)))
assert(!(Gaussian(3,3) == Gaussian(3,4)))

# Inequality is the negation of the above.
assert(Gaussian(3,3) != Gaussian(3,4))
assert(Gaussian(3,4) != Gaussian(3,3))
assert(!(Gaussian(3,4) != Gaussian(3,4)))

# Three-way comparison yields 1, 0 or -1.
assert_eq(Gaussian(4,5) <=> Gaussian(3,4), 1)
assert_eq(Gaussian(3,4) <=> Gaussian(3,4), 0)
assert_eq(Gaussian(3,4) <=> Gaussian(3,5), -1)

# Relational operators. The expectations here and below are consistent with
# ordering by the first component and then by the second.
assert(Gaussian(4,5) >  Gaussian(3,4))
assert(Gaussian(4,5) >  Gaussian(4,4))
assert(Gaussian(4,5) >= Gaussian(3,4))
assert(Gaussian(3,4) >= Gaussian(3,4))
assert(Gaussian(3,4) <= Gaussian(3,4))

# Negative cases for the relational operators.
assert(!(Gaussian(3,4) > Gaussian(3,5)))
assert(!(Gaussian(3,4) > Gaussian(4,5)))
assert(!(Gaussian(3,4) < Gaussian(3,4)))
assert(!(Gaussian(3,4) < Gaussian(3,1)))
assert(!(Gaussian(3,4) < Gaussian(2,1)))
assert(!(Gaussian(3,4) <= Gaussian(3,3)))
assert(!(Gaussian(2,1) >= Gaussian(2,2)))

# Remainder identity with a random integer divisor y: x - floor(x/y)*y == x % y.
for a in (r), b in (r) {
    var y = irand(2, 100)
    var x = Gaussian(a,b)
    assert_eq(x - floor(x/y)*y, x % y)
}

func gaussian_sum_2(n) {

    # Gaussian(0, 1) is the base whose successive powers are summed.
    var unit = Gaussian(0, 1)
    var acc  = Gaussian(0)

    (1..n).each {|k|
        # k-th term: the (k-1)-th power of the base, divided by k.
        acc = (acc + (unit.pow(k-1) / k))
    }

    # Scale the sum by n factorial.
    return (acc * n!)
}

# gaussian_sum_2(0) .. gaussian_sum_2(9) must match the listed Gaussian values.
assert_eq(
    10.of(gaussian_sum_2),
    [Gaussian(0, 0), Gaussian(1, 0), Gaussian(2, 1), Gaussian(4, 3), Gaussian(16, 6), Gaussian(104, 30), Gaussian(624, 300), Gaussian(3648, 2100), Gaussian(29184, 11760), Gaussian(302976, 105840)]
)

# Modular exponentiation: powmod() and Mod(...)**k must give the same value.
assert_eq(powmod(Gaussian(3,4), 1000, 1e6), Gaussian(585313, 426784))
assert_eq([Mod(Gaussian(3,4), 1e6)**1000 -> lift.reals], [585313, 426784])
assert_eq(Mod(Gaussian(3,4), 1e6)**1000, Mod(Gaussian(585313, 426784), 1e6))
assert(Mod(Gaussian(3,4), 1e6)**1000 == Mod(Gaussian(585313, 426784), 1e6))

# Conversions to plain numbers.
# NOTE(review): the second assertion compares a value with itself, so it only
# checks that to_n can be called; presumably an explicit expected value such as
# 3+4i was intended -- confirm.
assert_eq(Mod(43, 97).to_n, 43)
assert_eq(Gaussian(3,4).to_n, Gaussian(3,4).to_n)
assert_eq(Gaussian(3,4).to_c, 3+4i)

# invmod() against fixed expected values.
assert_eq(Gaussian(42).invmod(2017), Gaussian(1969, 0))
assert_eq(Gaussian(3,4).invmod(2017), Gaussian(1291, 968))
assert_eq(Gaussian(91,23).invmod(2017), Gaussian(590, 405))
assert_eq(Gaussian(43, 99).invmod(2017), Gaussian(1709,1272))
assert_eq(Gaussian(43, 99).invmod(1234567), Gaussian(1019551, 667302))

# The same inverses obtained through Mod(...).inv and negative powers of Mod(...).
assert_eq(Mod(Gaussian(42), 2017).inv, Mod(Gaussian(1969, 0), 2017))
assert_eq(Mod(Gaussian(3,4), 2017)**(-1), Mod(Gaussian(1291, 968), 2017))
assert_eq(Mod(Gaussian(91,23), 2017)**(-1), Mod(Gaussian(590, 405), 2017))
assert_eq(Mod(Gaussian(43, 99), 2017).inv, Mod(Gaussian(1709,1272), 2017))
assert_eq(Mod(Gaussian(43, 99), 1234567)**(-2), Mod(Gaussian(1019551, 667302)**2, 1234567))
assert_eq(Mod(Gaussian(43, 99), 1234567)**(-5), Mod(Gaussian(1019551, 667302)**5, 1234567))

# powmod() with a negative exponent must equal invmod() of the positive power,
# whether or not that power was reduced first.
assert_eq(powmod(Gaussian(43, 99), -4, 1234567), invmod(Gaussian(43, 99)**4 % 1234567, 1234567))
assert_eq(powmod(Gaussian(43, 99), -5, 1234567), invmod(Gaussian(43, 99)**5 % Gaussian(1234567), 1234567))
assert_eq(powmod(Gaussian(43, 99), -5, 1234567), invmod(Gaussian(43, 99)**5, 1234567))

assert_eq(powmod(Gaussian(43, 99), -4, 1234567), invmod(powmod(Gaussian(43, 99), 4, 1234567), 1234567))
assert_eq(powmod(Gaussian(43, 99), -5, 1234567), invmod(powmod(Gaussian(43, 99), 5, 1234567), 1234567))

# Negative powers without a modulus: four equivalent ways to write x**(-5).
assert_eq(Gaussian(43, 97)**(-5), (Gaussian(43,97)**5)**(-1))
assert_eq(Gaussian(43, 97)**(-5), (Gaussian(43,97)**5).inv)
assert_eq(Gaussian(43, 97)**(-5), (Gaussian(43,97).inv)**5)
assert_eq(Gaussian(43, 97)**(-5), (Gaussian(43,97)**(-1))**5)

# The same identities under Mod(...), for a small and a large exponent.
assert_eq(Mod(Gaussian(43, 97), 1234567)**(-5),  Mod(Gaussian(43, 97), 1234567)**5 -> inv)
assert_eq(Mod(Gaussian(43, 97), 1234567)**(-5), (Mod(Gaussian(43, 97), 1234567)->inv)**5)

assert_eq(Mod(Gaussian(43, 97), 1234567)**(-1234),  Mod(Gaussian(43, 97), 1234567)**1234 -> inv)
assert_eq(Mod(Gaussian(43, 97), 1234567)**(-1234), (Mod(Gaussian(43, 97), 1234567)->inv)**1234)
assert_eq(Mod(Gaussian(3,4), 1234567)**1234 -> lift, powmod(Gaussian(3,4), 1234, 1234567))
assert_eq(Mod(Gaussian(3,4), 1234567)**-1234 -> lift, powmod(Gaussian(3,4), -1234, 1234567))

# Rational components.
assert_eq(Gaussian(3/5,11/4)**(-27), Gaussian(3/5,11/4)**27 -> inv)
assert_eq(Gaussian(3/5,11/4)**(-27), Gaussian(3/5,11/4).inv**27)

# Components that are themselves Gaussian or Mod values, and Mod wrapped
# around Gaussian values with integer or rational components.
assert_eq(Gaussian(Gaussian(3,4), Gaussian(17,19)).to_n.to_n, -16+21i)
assert_eq(Gaussian(Mod(13, 97), Mod(43, 97)).to_n.to_n, 13+43i)
assert_eq(Gaussian(Mod(13, 97), Mod(43, 97)), Gaussian(Mod(13, 97), Mod(43, 97)))
assert_eq(Mod(Gaussian(3, 4), 97)*1234, Mod(Gaussian(16, 86), 97))
assert_eq(Mod(Gaussian(3/4, 5/6), 1234567)**10 * Mod(Gaussian(3/4, 5/6), 1234567)**-10, Mod(Gaussian(1,0), 1234567))

assert_eq((Mod(Gaussian(43/3, 97/5), 127)**(-11) * Mod(Gaussian(43/3, 97/5), 127))**+9, Mod(Gaussian(52, 73), 127))
assert_eq((Mod(Gaussian(43/3, 97/5), 127)**(-11) * Mod(Gaussian(43/3, 97/5), 127))**-9, Mod(Gaussian(81, 89), 127))

# A plain integer on the left-hand side of each arithmetic operator.
assert_eq((3 + Gaussian(4, 5)), Gaussian(7, 5))
assert_eq((3 - Gaussian(4, 5)), Gaussian(-1, -5))
assert_eq((3 * Gaussian(4, 5)), Gaussian(12, 15))
assert_eq((3 / Gaussian(4, 5)), Gaussian(12/41, -15/41))

# Each row [a, b, c, d] describes the value Gaussian(a/b, c/d) used below.
var params = [
    [3, 4, 5, 6],
    [3, 4, 5, -2],
    [3,-11, 7, 23],
    [-9, -4, -1, -4],
    [0, -4, 1, 1],
    [0, -1, 13, 12],
    [5, 1, 7, 1],
    [1, 3, 0, 1],
]

# For every row, compute Gaussian(a/b, c/d)^n modulo m in several ways and
# require them all to agree, for both the negative and the positive exponent.
params.each_2d {|a,b,c,d|

    var m = (2**64 + 1)

    for n in (-274176, 274176) {

        # Integer-only route: a/b and c/d share the denominator b*d, so raise
        # Gaussian(a*d, b*c) to the n-th power and multiply by (b*d)^(-n).
        var x = powmod(Gaussian(a*d, b*c), n, m)
        var y = powmod(b*d, -n, m)

        var r1 = (x * y)%m
        var r2 = powmod(Gaussian(a/b, c/d), n, m)
        var r3 = Mod(Gaussian(a/b, c/d), m)**n
        var r4 = Mod(Gaussian(a/b, c/d), m)**(-n)

        say "Gaussian(#{a}/#{b}, #{c}/#{d})^#{n} == #{r1} (mod m)"

        assert_eq(r1, r2)
        assert_eq(r3.lift, r1)

        assert(r1 == r2)
        assert(r1 == r3.lift)

        # A power times the opposite power must give 1.
        assert_eq(r3 * r4 -> lift, Gaussian(1, 0))

        # Unless the result is exactly 1, (first component - 1) must have
        # gcd 274177 with m.
        if (r1 != Gaussian(1,0)) {
            assert_eq(gcd(r1.re-1, m), 274177)
            assert_eq(gcd(r2.re-1, m), 274177)
        }
    }
}

say "** Test passed!"