Miller-Rabin primality test (Scala)

Other implementations: C | C, GMP | Clojure | Groovy | Java | Python | Ruby | Scala

The Miller-Rabin primality test is a simple, easy-to-implement probabilistic algorithm for determining whether a number is prime or composite. It proves compositeness of a number using the following formulas:

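Suppose 0 < a < n is coprime to n (this is easy to test using the GCD). Write the number n−1 as $n - 1 = 2^s \cdot d$, where d is odd. Then, provided that all of the following formulas hold, n is composite:

$a^{d} \not\equiv 1 \pmod{n}$

$a^{2^r d} \not\equiv -1 \pmod{n}$ for all $0 \le r \le s - 1$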

If n is prime, these formulas never all hold, so the test never reports a prime as composite. If n is composite and a is chosen uniformly at random, the formulas fail to hold (a is a "strong liar") with probability at most 1/4. Thus a composite n survives k rounds with independent random choices of a with probability at most $4^{-k}$. Moreover, Gerhard Jaeschke showed that any 32-bit number can be deterministically tested for primality by trying only a = 2, 7, and 61.
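To make the formulas concrete, here is a minimal sketch that lists the values $a^{2^r d} \bmod n$ for $r = 0, \dots, s-1$; the helper names powerSequence and isWitness are hypothetical and not part of the program below. For the Carmichael number 561 = 3 · 11 · 17 we have $2^{560} \equiv 1 \pmod{561}$, so a plain Fermat test with base 2 is fooled, but the sequence 263, 166, 67, 1 neither starts at 1 nor contains 560, so a = 2 proves 561 composite. For 221 = 13 · 17, base 174 is a strong liar (sequence 47, 220) while base 137 is a witness (sequence 188, 205).

```scala
// Hypothetical helpers for illustration; both assume n is odd and n >= 3.

// Values a^(2^r * d) mod n for r = 0 .. s-1, where n - 1 = 2^s * d with d odd.
def powerSequence(a: BigInt, n: BigInt): List[BigInt] = {
  var d = n - 1
  var s = 0
  while (!d.testBit(0)) { d >>= 1; s += 1 }
  List.iterate(a.modPow(d, n), s)(x => (x * x) % n)
}

// a proves n composite when the sequence neither starts at 1 nor contains n - 1.
def isWitness(a: BigInt, n: BigInt): Boolean = {
  val seq = powerSequence(a, n)
  seq.head != BigInt(1) && !seq.contains(n - 1)
}

println(powerSequence(2, 561)) // List(263, 166, 67, 1)
println(isWitness(2, 561))     // true
```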

We will implement the test for arbitrary-precision integers using Scala's BigInt type, which is easier to work with than Java's BigInteger type because it supports the ordinary arithmetic and comparison operators.
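As a small illustration of the difference, here is the same modular squaring step written both ways; the function names are illustrative only. Scala's BigInt wraps a java.math.BigInteger (reachable through its bigInteger field), so the two styles interoperate directly.

```scala
import java.math.BigInteger

// Scala BigInt: ordinary operators, and Int literals convert implicitly.
// For non-negative operands, BigInt's % (a remainder) agrees with BigInteger.mod.
def squareModScala(x: BigInt, n: BigInt): BigInt = (x * x) % n

// Java BigInteger: every operation is a method call.
def squareModJava(x: BigInteger, n: BigInteger): BigInteger = x.multiply(x).mod(n)

// Converting between the two representations
val viaJava: BigInt = BigInt(squareModJava(BigInt(263).bigInteger, BigInteger.valueOf(561)))
println(squareModScala(263, 561) == viaJava) // true
```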

When performing the Miller-Rabin pass, only a, n, d, and powers of a are large numbers; s and the loop counter i are logarithmically smaller, so an Int suffices for them. We precompute the large integer n−1, which is compared against after every squaring:

<<Miller-Rabin pass>>=
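def miller_rabin_pass(a: BigInt, n: BigInt): Boolean = {
  var d: BigInt = 0
  var s: Int = 0            // s and i stay small, so Int is enough
  val n_minus_one: BigInt = n - 1
  <<compute s and d>>
  // r = 0: compute a^d mod n directly
  var a_to_power: BigInt = a.modPow(d, n)
  if (a_to_power == 1 || a_to_power == n_minus_one) {
    return true
  }
  // r = 1 .. s-1: square repeatedly, stopping early once we reach n - 1
  var i: Int = 1
  while (i < s && a_to_power != n_minus_one) {
    a_to_power = (a_to_power * a_to_power) % n
    i += 1
  }
  a_to_power == n_minus_one // false means a witnesses that n is composite
}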

Rather than calling modPow again for each squaring, we multiply a_to_power by itself and reduce modulo n. Finally, we use bit shifts to compute d rapidly, and test just its least significant bit to determine whether it is odd:

<<compute s and d>>=
d = n - 1
s = 0
// Strip factors of 2: shift right while the lowest bit of d is 0
while (!d.testBit(0)) {
  d >>= 1
  s += 1
}
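BigInt also exposes lowestSetBit (the index of the lowest 1 bit, i.e. the number of trailing zero bits), so s and d can be obtained without a loop. This is an alternative sketch, not what the chunk above uses; decompose is a hypothetical name.

```scala
// Returns (s, d) with n - 1 == 2^s * d and d odd.
// Assumes n >= 3 is odd, so n - 1 is a positive even number.
def decompose(n: BigInt): (Int, BigInt) = {
  require(n >= 3 && n.testBit(0), "n must be an odd integer >= 3")
  val m = n - 1
  val s = m.lowestSetBit // number of trailing zero bits of n - 1
  (s, m >> s)
}

println(decompose(561)) // (4,35)
```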

Since our numbers can now be arbitrarily large, we can no longer rely on Jaeschke's result to produce a deterministic implementation. We can, however, produce a highly accurate probabilistic implementation by simply running the test enough times to make $4^{-k}$ very small, say k = 20 times, which bounds the chance of accepting a composite by $4^{-20} \approx 9.1 \times 10^{-13}$.

<<Miller-Rabin>>=
def miller_rabin(n: BigInt): Boolean = {
  val k: Int = 20                          // number of independent rounds
  val rand: scala.util.Random = new scala.util.Random()
  (1 to k).forall { _ =>
    // Pick a uniformly from [1, n-1]: draw n.bitLength random bits,
    // rejecting 0 and values >= n
    var a: BigInt = BigInt(n.bitLength, rand)
    while (a == 0 || a >= n) {
      a = BigInt(n.bitLength, rand)
    }
    miller_rabin_pass(a, n)                // forall stops at the first witness
  }
}

For convenience we ignore the special cases where n ≤ 2.
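For inputs that fit in 32 bits, Jaeschke's result mentioned above removes the need for random bases: testing only a = 2, 7 and 61 gives an exact answer. The following is a minimal sketch that assumes that result and reuses the same strong-probable-prime logic as miller_rabin_pass; the names isStrongProbablePrime and isPrime32 are hypothetical. Unlike the code above, it also handles n ≤ 2 explicitly.

```scala
// Strong probable-prime test to base a (same logic as miller_rabin_pass).
// Assumes n is odd and n >= 3.
def isStrongProbablePrime(a: BigInt, n: BigInt): Boolean = {
  val nMinusOne = n - 1
  val s = nMinusOne.lowestSetBit
  var x = a.modPow(nMinusOne >> s, n)
  if (x == 1 || x == nMinusOne) return true
  var r = 1
  while (r < s && x != nMinusOne) {
    x = (x * x) % n
    r += 1
  }
  x == nMinusOne
}

// Deterministic for 0 <= n < 2^32, relying on Jaeschke's bases 2, 7 and 61.
def isPrime32(n: Long): Boolean = {
  require(n >= 0 && n < (1L << 32), "n must fit in 32 unsigned bits")
  if (n < 2) false
  else if (n < 4) true // 2 and 3
  else if (n % 2 == 0) false
  else Seq(2L, 7L, 61L).forall { a =>
    // A base equal to n (n = 7 or 61) carries no information, so skip it
    a % n == 0 || isStrongProbablePrime(BigInt(a), BigInt(n))
  }
}

println(isPrime32(4294967291L)) // true: the largest prime below 2^32
```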

Finally, some test code:

<<MillerRabin.scala>>=
<<Miller-Rabin pass>>
<<Miller-Rabin>>
var n: BigInt = 0
var nbits: BigInt = 0
var p: BigInt = 0
if (args(0) == "test") {
  n = BigInt(args(1))
  if (miller_rabin(n))
    println("PRIME")
  else
    println("COMPOSITE")
}
else if (args(0) == "genprime") {
  nbits = BigInt(args(1))
  // Seed once: re-seeding from the clock inside a loop can repeat candidates
  var rand: java.util.Random = new java.util.Random()
  <<test for small factors>>
}

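We've augmented the script with the ability to generate a prime number of at most a specified number of bits. It does this by randomly selecting numbers of that size until it finds a prime one. We could use this to rapidly generate large primes for use in cryptography, although java.util.Random is not cryptographically secure, so a real application would use java.security.SecureRandom instead. Scala has no continue or break statement for while or for loops, so every rejection condition goes into the loop condition itself. Since most random values have small prime factors (about 77% are divisible by 2, 3, 5, or 7), we test for these first and call the expensive Miller-Rabin test only on the candidates that survive: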

<<test for small factors>>=
// Draw candidates until one survives trial division by 2, 3, 5 and 7 and then
// Miller-Rabin; || short-circuits, so the cheap tests run first.
// p < 2 rejects 1 (0 is already caught by p % 2), which miller_rabin does not handle.
p = BigInt(new java.math.BigInteger(nbits.toInt, rand))
while (p < 2 || p % 2 == 0 || p % 3 == 0 || p % 5 == 0 || p % 7 == 0 || !miller_rabin(p)) {
  p = BigInt(new java.math.BigInteger(nbits.toInt, rand))
}
println(p)

Here's some sample output:

$ scala MillerRabin.scala test 516119616549881
PRIME
$ scala MillerRabin.scala test 516119616549887
COMPOSITE
$ scala MillerRabin.scala genprime 128
137754824349890671048760329327009610963
$ scala MillerRabin.scala genprime 512
35322546362785232466089390512753369849545354435096342125169009992469446061001698\
19822947596123187748351497917229392517283160261743596143268326745544727849
$ scala MillerRabin.scala genprime 1500
14135574239255638273940288776111990386502708867557424837773738886845834621411786\
95258780680296411416405689850909043309196578806612666295351660827556035642309715\
46019369868738756554802120048859996557215815760674482493409593262493125763100813\
68156166868368851576234171267515711784418855252212173969695748721023374781918665\
46154006939558876855762108831079983131994949499899243208648727721111252390364144\
878481543638406375972014256048962532894748465344627
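Because the BigInteger(numBits, rnd) constructor used by genprime draws uniformly from [0, 2^numBits − 1], the result has at most the requested number of bits; the 128-bit sample above, for instance, has 127 bits. If an exact bit length is wanted, one common variation is to force the top bit (and the low bit, so the candidate is odd) before testing. The following is a self-contained sketch of that variation under those assumptions, with hypothetical names; it repeats a compact strong-probable-prime test so it can run on its own.

```scala
import scala.util.Random

// Strong probable-prime test to base a; assumes n is odd and n >= 5.
def isStrongProbablePrime(a: BigInt, n: BigInt): Boolean = {
  val nMinusOne = n - 1
  val s = nMinusOne.lowestSetBit
  var x = a.modPow(nMinusOne >> s, n)
  if (x == 1 || x == nMinusOne) return true
  var r = 1
  while (r < s && x != nMinusOne) {
    x = (x * x) % n
    r += 1
  }
  x == nMinusOne
}

// k rounds with bases drawn uniformly from [2, n - 2] by rejection sampling.
def probablyPrime(n: BigInt, k: Int, rnd: Random): Boolean =
  (1 to k).forall { _ =>
    var a = BigInt(n.bitLength, rnd)
    while (a < 2 || a > n - 2) a = BigInt(n.bitLength, rnd)
    isStrongProbablePrime(a, n)
  }

// Hypothetical helper: a random probable prime with exactly nbits bits.
// nbits >= 4 guarantees every candidate is an odd number >= 9.
def randomPrimeExactBits(nbits: Int, rnd: Random, k: Int = 20): BigInt = {
  require(nbits >= 4, "nbits must be at least 4")
  def candidate(): BigInt = BigInt(nbits, rnd).setBit(nbits - 1).setBit(0)
  var p = candidate()
  // Cheap trial division first, then k Miller-Rabin rounds
  while (p % 3 == 0 || p % 5 == 0 || p % 7 == 0 || !probablyPrime(p, k, rnd)) {
    p = candidate()
  }
  p
}

val p = randomPrimeExactBits(128, new Random())
println(s"$p has ${p.bitLength} bits")
```

Forcing the top bit narrows the search to [2^(nbits−1), 2^nbits), where primes are still plentiful, so the expected number of candidates barely changes.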