Dijkstra's algorithm (Scala)
Here is an example of Dijkstra's shortest path algorithm in Scala.
Shortest path
This procedural implementation uses sets and dictionaries of the types HashSet and HashMap from scala.collection.mutable, as imported in the test module. The custom-written WeightedDiGraph class has a vertex type parameter V, which the test module instantiates as String. The arc cost type of WeightedDiGraph is Float, but it could be Int or Double.
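The cost type is hard-wired to Float. One way to let it be Int or Double as well is a second type parameter constrained by an implicit scala.math.Numeric instance. The following is a minimal sketch under that assumption: CostGraph and distances are hypothetical names, and the algorithm is the same linear-scan selection used in the article, without predecessor tracking.

```scala
import scala.collection.mutable.{HashMap, HashSet}

// Hypothetical sketch (not the article's class): a weighted digraph generic in
// the vertex type V and the cost type C. Any C with a Numeric instance works,
// e.g. Int, Float or Double.
class CostGraph[V, C](implicit num: Numeric[C]) {
  import num._ // brings zero, + and < for values of type C into scope

  // origin -> (destination -> cost), the same layout as WeightedDiGraph
  val arcs = new HashMap[V, HashMap[V, C]]

  def addArc(from: V, to: V, cost: C): Unit = {
    val adjMap = arcs.getOrElseUpdate(from, new HashMap[V, C])
    adjMap += (to -> cost)
  }

  // Shortest distances from source to every reachable vertex, picking the
  // closest frontier vertex by a linear scan as the article does.
  def distances(source: V): Map[V, C] = {
    val dist = new HashMap[V, C]
    val settled = new HashSet[V]
    val frontier = new HashSet[V]
    dist += (source -> zero)
    frontier += source
    while (frontier.nonEmpty) {
      val u = frontier.minBy(v => dist(v)) // Numeric[C] also serves as Ordering[C]
      frontier -= u
      settled += u
      for ((v, c) <- arcs.getOrElse(u, HashMap.empty[V, C]) if !settled.contains(v)) {
        val nd = dist(u) + c
        if (!dist.contains(v) || nd < dist(v)) {
          dist += (v -> nd)
          frontier += v
        }
      }
    }
    dist.toMap
  }
}
```

The article's graph corresponds to CostGraph[String, Float]; switching to integer distances only changes the type argument, e.g. new CostGraph[String, Int].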
<<shortestPath>>=
/**
 * @param start start vertex
 * @param end   end vertex
 * @return route as a list of (vertex, shortest distance) pairs, not including start;
 *         Nil if end is not reached (unconnected graph)
 */
def shortestPath(start: V, end: V): List[(V, Float)] = {
  if (!arcs.contains(start))
    throw new MyException("shortestPath: " + start + " not in graph edge origins")
  if (!vertices.contains(end))
    throw new MyException("shortestPath: " + end + " not in graph vertices")
  // get shortest distances, predecessors and whether end was reached
  val (dist, pred, endReached) = dijkstra(start, end)
  // build the path backwards from end by following predecessors
  var path: List[(V, Float)] = Nil
  if (endReached) {
    var v = end
    while (v != start) {
      path = (v, dist(v)) :: path
      v = pred(v) // step back to the predecessor
    }
  }
  // return path
  path
}
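The loop in shortestPath walks the predecessor map backwards from end and prepends each vertex, so the list comes out in travel order. The same step can be sketched without mutable variables, assuming the distances and predecessors are already available as immutable Maps; PathBuilder and buildPath are hypothetical names.

```scala
import scala.annotation.tailrec

object PathBuilder {
  // Rebuild the route from `end` back to `start` through the predecessor map,
  // returning (vertex, distance) pairs in travel order, start excluded.
  // Returns Nil when the predecessor chain breaks (end not reached).
  def buildPath[V](start: V, end: V, dist: Map[V, Float], pred: Map[V, V]): List[(V, Float)] = {
    @tailrec
    def loop(v: V, acc: List[(V, Float)]): List[(V, Float)] =
      if (v == start) acc
      else pred.get(v) match {
        case Some(p) => loop(p, (v, dist(v)) :: acc) // prepend, then step back
        case None    => Nil // no predecessor: end was not reached from start
      }
    loop(end, Nil)
  }
}
```

As in the original, a route from a vertex to itself also yields Nil.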
Dijkstra's algorithm
<<dijkstra>>=
/**
 * @param source start vertex
 * @param end    end vertex, or null to compute distances to all reachable vertices
 * @return (distances, predecessors, endReached)
 */
def dijkstra(source: V, end: V): (HashMap[V, Float], HashMap[V, V], Boolean) = {
  assume(arcs.contains(source), "source not in arcs origins")
  if (end != null) assume(vertices.contains(end), "end not in graph")
  // initialize
  val dist = new HashMap[V, Float]  // tentative, then final, distances
  val Q = new HashSet[V]            // frontier: scanned linearly, not a real priority queue
  val Settled = new HashSet[V]      // vertices whose distance is final
  val pred = new HashMap[V, V]      // predecessors on the shortest paths

  <<minimumDistVertexDefinition>>

  // start with the source vertex
  dist += source -> 0F
  Q += source
  var endReached = false
  while (!Q.isEmpty && !endReached) {
    // extract the minimum-distance vertex from Q and settle it
    val u = minimumDistVertex(Q)
    Q -= u
    Settled += u
    if (end != null) endReached = (u == end)
    // update neighbours' distances and add the improved ones to Q
    if (!endReached)
      for (v <- adjacents(u) if !Settled.contains(v)) {
        val vNewDist = dist(u) + cost(u, v)
        if (!dist.isDefinedAt(v) || vNewDist < dist(v)) {
          dist += v -> vNewDist
          pred += v -> u
          Q += v
        }
      }
  }
  // return distances, predecessors, endReached
  (dist, pred, endReached)
}
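Because Q is a HashSet scanned linearly by minimumDistVertex, each extraction costs O(|V|), so the whole run is O(|V|²). A common alternative, sketched here and not part of the original program, keeps the frontier in scala.collection.mutable.PriorityQueue (a binary heap) and skips outdated entries when they are popped ("lazy deletion"), which gives O((|V| + |E|) log |V|). The object name PQDijkstra and the immutable adjacency-map input are assumptions for illustration.

```scala
import scala.collection.mutable

object PQDijkstra {
  // Shortest distances from `source` over an adjacency map
  // origin -> (destination -> cost). A vertex may be queued several times;
  // entries for already settled vertices are skipped when popped.
  def distances[V](adj: Map[V, Map[V, Float]], source: V): Map[V, Float] = {
    // PriorityQueue dequeues its greatest element, so treat a smaller
    // distance as "greater" to get the closest vertex first.
    implicit val closestFirst: Ordering[(Float, V)] =
      Ordering.fromLessThan[(Float, V)]((a, b) => a._1 > b._1)
    val queue = mutable.PriorityQueue[(Float, V)]((0f, source))
    val dist = mutable.HashMap[V, Float](source -> 0f)
    val settled = mutable.HashSet[V]()
    while (queue.nonEmpty) {
      val (d, u) = queue.dequeue()
      if (!settled.contains(u)) { // stale entries are ignored
        settled += u
        for ((v, c) <- adj.getOrElse(u, Map.empty[V, Float]) if !settled.contains(v)) {
          val nd = d + c
          if (dist.get(v).forall(nd < _)) { // first visit or shorter path found
            dist(v) = nd
            queue.enqueue((nd, v))
          }
        }
      }
    }
    dist.toMap
  }
}
```

On the test data, with each arc added in both directions, this reproduces the Barcelona distances of the original program, for example 864 to Lausanne.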
Minimum distance vertex from set.
Note that in the assume precondition, instead of passing the usual anonymous function, as in Q.forall(v => dist.isDefinedAt(v)), we can pass dist.isDefinedAt without its argument: the compiler turns the method into a function value of the required type V => Boolean (eta-expansion).
<<minimumDistVertexDefinition>>=
// local to dijkstra: dist is the distance map of the enclosing method
def minimumDistVertex(Q: HashSet[V]): V = {
  assume(!Q.isEmpty && Q.forall(dist.isDefinedAt))
  val iterator = Q.iterator
  val w = iterator.next() // first element; exists because Q is not empty
  // scan the rest, keeping the vertex with the smaller tentative distance
  iterator.foldLeft(w) { (u, v) => if (dist(u) <= dist(v)) u else v }
}
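A Scala Map is itself a function from keys to values, so the fold can also be written with minBy, reusing the same eta-expansion idea. The sketch below passes Q and dist as parameters so it stands alone; MinDist and both method names are hypothetical.

```scala
import scala.collection.mutable.{HashMap, HashSet}

object MinDist {
  // Same contract as the article's minimumDistVertex: q is non-empty and
  // every vertex in q already has a tentative distance.
  def minimumDistVertex[V](q: HashSet[V], dist: HashMap[V, Float]): V = {
    // dist.isDefinedAt is eta-expanded to a V => Boolean function value
    assume(q.nonEmpty && q.forall(dist.isDefinedAt))
    val it = q.iterator
    val first = it.next() // safe: q is not empty
    // keep whichever of the two vertices is closer
    it.foldLeft(first)((u, v) => if (dist(u) <= dist(v)) u else v)
  }

  // Alternative: the map itself is the V => Float key function for minBy.
  // Passing TotalOrdering explicitly avoids the deprecated implicit Float
  // ordering of Scala 2.13.
  def minimumDistVertexMinBy[V](q: HashSet[V], dist: HashMap[V, Float]): V =
    q.minBy(dist)(Ordering.Float.TotalOrdering)
}
```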
Our class WeightedDiGraph implements a directed graph with arc costs.
It stores the arcs as a dictionary that maps each origin vertex to its adjacency map, which in turn maps each destination vertex to the cost of the arc.
<<WeightedDiGraph>>=
class WeightedDiGraph[V] {
  // arcs: origin -> (destination -> cost)
  val arcs = new HashMap[V, HashMap[V, Float]]
  val vertices = new HashSet[V]

  // destinations one arc away from u (empty if u has no outgoing arcs)
  def adjacents(u: V): Iterable[V] = arcs.getOrElse(u, HashMap.empty[V, Float]).keys
  def cost(u: V, v: V): Float = arcs(u)(v)

  def addArc(from: V, to: V, kost: Float): Unit = {
    // reuse the adjacency map of `from`, creating it on first use
    val adjMap = arcs.getOrElseUpdate(from, new HashMap[V, Float])
    adjMap += to -> kost
    vertices += from
    vertices += to
  }

  <<shortestPath>>
  <<dijkstra>>
}
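The nested-dictionary layout (origin to destination to cost) does not depend on mutation. As a hedged illustration, the same structure can be built immutably from a list of (from, to, cost) triples with groupBy; AdjacencyMap, Graph and fromArcs are hypothetical names, and adjacents and cost mirror the article's methods.

```scala
object AdjacencyMap {
  type Graph[V] = Map[V, Map[V, Float]]

  // Build origin -> (destination -> cost) from arc triples.
  // If the same (from, to) pair appears twice, the later cost wins.
  def fromArcs[V](arcs: Seq[(V, V, Float)]): Graph[V] =
    arcs.groupBy(_._1).map { case (from, outgoing) =>
      from -> outgoing.map { case (_, to, cost) => to -> cost }.toMap
    }

  // Counterparts of WeightedDiGraph.adjacents and WeightedDiGraph.cost
  def adjacents[V](g: Graph[V], u: V): Iterable[V] = g.getOrElse(u, Map.empty[V, Float]).keys
  def cost[V](g: Graph[V], u: V, v: V): Float = g(u)(v)
}
```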
Test module
We load the XML test data file into a scala.xml.Elem (a kind of scala.xml.Node) and then apply the XPath-like operator \, which selects the child nodes with the specified element tag. The similar operator \\ selects descendants with the specified tag at any depth, like // in XPath. Attribute values are selected with the "@name" form, as in arc \ "@from", and read with .text.
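To illustrate the selectors on something smaller than the test file, here is a sketch that parses a cut-down document with the same shape as test_data.xml. It assumes the scala-xml library is on the classpath (it ships as a separate module since Scala 2.11); the object name XmlSelectors is hypothetical, and XML.loadString is used instead of XML.loadFile so that no file is needed.

```scala
import scala.xml.{Elem, XML}

object XmlSelectors {
  val doc: Elem = XML.loadString(
    """<test_data>
      |  <graph>
      |    <arc from="Geneve" to="Lausanne" cost="64" />
      |  </graph>
      |  <sources>
      |    <test source="Barcelona" dest="Lausanne"/>
      |  </sources>
      |</test_data>""".stripMargin)

  // \ selects direct children only: the <arc> elements under <graph>
  val arcs = doc \ "graph" \ "arc"
  // \\ selects matching descendants at any depth
  val allArcs = doc \\ "arc"
  // "@name" selects an attribute; .text gives its string value
  val firstFrom: String = (arcs.head \ "@from").text
  val firstCost: Float = (arcs.head \ "@cost").text.toFloat
}
```

Note that doc \ "arc" is empty here, because <arc> is a grandchild of <test_data>, not a child.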
<<ShortestPathTest.scala>>=
package test // folder: test

import scala.collection.mutable.{HashSet, HashMap}
import scala.xml.XML

object ShortestPathTest {
  class MyException(msg: String) extends RuntimeException(msg)

  <<WeightedDiGraph>>

  def main(args: Array[String]): Unit = {
    val roadMap = new WeightedDiGraph[String]
    val test_data = XML.loadFile("test_data.xml")
    // roads are two-way: add every arc in both directions
    for (arc <- test_data \ "graph" \ "arc") {
      val from = (arc \ "@from").text
      val to = (arc \ "@to").text
      val cost = (arc \ "@cost").text.toFloat
      roadMap.addArc(from, to, cost)
      roadMap.addArc(to, from, cost)
    }
    // for each test entry in test_data
    for (test <- test_data \ "sources" \ "test") {
      val source = (test \ "@source").text
      val dest = (test \ "@dest").text
      Console.println("from " + source + "\n")
      try {
        val route = roadMap.shortestPath(source, dest)
        if (route == Nil)
          Console.println("No route to " + dest)
        else
          for ((city, distance) <- route)
            Console.println(city + ": " + distance)
      } catch {
        case m: MyException => Console.println("MyException: " + m.getMessage)
        case e: Throwable => Console.println(e.getMessage)
      }
      Console.println("--")
    }
  }
}
Test data
Approximate road distances.
Barcelona is my city; Lausanne (Switzerland) hosts the EPFL (École Polytechnique Fédérale de Lausanne), where the Scala team develops the language.
<<test_data.xml>>=
<?xml version="1.0" encoding="ISO-8859-1"?>
<test_data>
  <graph>
    <arc from="Barcelona" to="Narbonne" cost="250" />
    <arc from="Narbonne" to="Marseille" cost="260" />
    <arc from="Narbonne" to="Toulouse" cost="150" />
    <arc from="Narbonne" to="Geneve" cost="550" />
    <arc from="Marseille" to="Geneve" cost="470" />
    <arc from="Toulouse" to="Paris" cost="680" />
    <arc from="Toulouse" to="Geneve" cost="700" />
    <arc from="Geneve" to="Paris" cost="540" />
    <arc from="Geneve" to="Lausanne" cost="64" />
    <arc from="Lausanne" to="Paris" cost="536" />
  </graph>
  <sources>
    <test source="Barcelona" dest="Lausanne"/>
    <test source="Lausanne" dest="Barcelona"/>
  </sources>
</test_data>
Compile and run
scalac test/ShortestPathTest.scala
scala test.ShortestPathTest
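Run from the directory that contains test_data.xml (the program loads it by a relative path) and, with Scala 2.11 or later, with the separate scala-xml module on the classpath of both commands, the program will produce the output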
from Barcelona

Narbonne: 250.0
Geneve: 800.0
Lausanne: 864.0
--
from Lausanne

Geneve: 64.0
Narbonne: 614.0
Barcelona: 864.0
--
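The printed distances can be checked by hand from test_data.xml. Barcelona's only arc leads to Narbonne, so every route to Lausanne starts with that 250 leg; the main candidates, listed by their intermediate stops, are:

```latex
\begin{align*}
\text{Narbonne, Geneve:}            &\quad 250 + 550 + 64 = 864\\
\text{Narbonne, Marseille, Geneve:} &\quad 250 + 260 + 470 + 64 = 1044\\
\text{Narbonne, Toulouse, Geneve:}  &\quad 250 + 150 + 700 + 64 = 1164\\
\text{Narbonne, Toulouse, Paris:}   &\quad 250 + 150 + 680 + 536 = 1616
\end{align*}
```

The first route is the cheapest, and its running totals 250, 800 and 864 are the lines printed for Barcelona. Because every arc is added in both directions, the trip from Lausanne follows the same roads in reverse: 64, then 64 + 550 = 614, then 614 + 250 = 864.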
