Dijkstra's algorithm (Scala)
From LiteratePrograms
Here is an example of Dijkstra's shortest path algorithm in Scala.
Shortest path
This procedural implementation uses sets and dictionaries of types HashSet and HashMap from scala.collection.mutable, as imported in the test module.
The custom-written WeightedDiGraph has a vertex type parameter V, which the test module instantiates as String.
The WeightedDiGraph cost type is Float, but could equally be Int or Double.
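The remark that the cost type could be Int or Double can be illustrated with a small sketch. It is not part of the original program: the helper routeCost is hypothetical, the syntax is that of current Scala (2.13 or later) rather than the older dialect used on this page, and it assumes the standard Numeric type class is an acceptable way to abstract over the cost type.

```scala
object GenericCostSketch {
  // Hypothetical helper (not in the original program): sums arc costs
  // for any cost type C that has a Numeric instance (Int, Float, Double, ...).
  def routeCost[C](costs: List[C])(implicit num: Numeric[C]): C =
    costs.foldLeft(num.zero)(num.plus)

  def main(args: Array[String]): Unit = {
    println(routeCost(List(250, 550, 64)))       // Int costs, prints 864
    println(routeCost(List(250.0f, 550.0f, 64.0f))) // Float costs, prints 864.0
  }
}
```

A WeightedDiGraph written along these lines would take the cost type as a second type parameter instead of fixing it to Float.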
<<shortestPath>>=
/**
 * @param start
 * @param end
 * @return route as list of pairs( node, shortest_distance), not including start
 *         Nil if end not reached ( unconnected graph)
 */
def shortestPath( start: V, end: V): List[Pair[V, Float]] = {
  if (!arcs.contains( start))
    throw new MyException( "shortestPath: "+start+" not in graph edge origins") ;
  if (!vertices.contains( end))
    throw new MyException( "shortestPath: "+end+" not in graph vertices") ;

  // get shortest-distances, predecessors, end result
  val Triple( dist, pred, endReached) = dijkstra( start, end) ;

  // build path from end based on predecessors
  var path: List[Pair[V, Float]] = Nil ;
  if (endReached) {
    var v = end ;
    while ( v != start) {
      path = Pair(v, dist(v)) :: path ;
      // iterate on predecessor
      v = pred( v)
    }
  }
  // return path
  path
}
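The backward walk over the predecessor map can also be written without mutable variables. The following is only a sketch in current Scala (2.13 or later), not the page's code: buildPath is a hypothetical name, the maps are immutable, and it assumes the end vertex was actually reached so that every vertex on the way back has a predecessor.

```scala
import scala.annotation.tailrec

object PathSketch {
  // Walks the predecessor map back from `end` to `start`, collecting
  // (vertex, distance) pairs. The start vertex itself is left out,
  // as in shortestPath above. Assumes `end` is reachable from `start`.
  def buildPath[V](start: V, end: V,
                   dist: Map[V, Float], pred: Map[V, V]): List[(V, Float)] = {
    @tailrec
    def loop(v: V, acc: List[(V, Float)]): List[(V, Float)] =
      if (v == start) acc
      else loop(pred(v), (v, dist(v)) :: acc)
    loop(end, Nil)
  }
}
```

Prepending to the accumulator while walking backwards yields the route in forward order, which is the same trick the while loop above relies on.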
Dijkstra's algorithm
<<dijkstra>>=
/**
 * @param source
 * @param end (end vertex or null)
 * @return Triple( distances, predecessors, endReached)
 */
def dijkstra( source: V, end: V) = {
  assume( arcs.contains( source), "source not in arcs origins")
  if (end != null) assume( vertices.contains( end), "end not in graph")

  // initialize
  val dist = new HashMap[V, Float] ; // distances
  val Q = new HashSet[V] ; // priority queue
  val Settled = new HashSet[V] ; // settled vertices
  val pred = new HashMap[V, V] ; // predecessors

  <<minimumDistVertexDefinition>>

  // start with source vertex
  dist += source -> 0F ;
  Q += source ;
  var endReached = false ;

  while (! Q.isEmpty && ! endReached) {
    // extract minimumDistVertex from Q, add to Settled ones
    val u = minimumDistVertex( Q) ;
    Q -= u ;
    Settled += u ;
    if (end != null) endReached = (u == end) ;

    // update neighbours distances
    // and add updated ones to Q
    if (! endReached)
      for( val v <- adjacents( u); ! Settled.contains( v)) {
        val vNewDist = dist( u) + cost(u, v) ;
        if ( ! dist.isDefinedAt( v) || vNewDist < dist(v)) {
          dist += v -> vNewDist ;
          pred += v -> u ;
          Q += v ;
        }
      }
  }
  // return distances, predecessors, endReached
  Triple( dist, pred, endReached)
}
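In the chunk above, Q is a HashSet, so each extraction of the minimum scans the whole set. As a point of comparison, here is a sketch of the same relaxation loop built on scala.collection.mutable.PriorityQueue. It is written in current Scala (2.13 or later), is not the page's implementation, and makes several assumptions of its own: the graph is a plain immutable nested Map, the name distances is hypothetical, and predecessors and early termination are left out for brevity.

```scala
import scala.collection.mutable

object PriorityQueueDijkstraSketch {
  // origin -> (destination -> cost); an assumed stand-in for WeightedDiGraph
  type Graph[V] = Map[V, Map[V, Float]]

  // Returns the shortest distance from `source` to every reachable vertex.
  def distances[V](graph: Graph[V], source: V): Map[V, Float] = {
    val dist = mutable.Map[V, Float](source -> 0f)
    val settled = mutable.Set.empty[V]
    // PriorityQueue dequeues the largest element, so the comparison is
    // inverted to make the smallest distance come out first.
    val queue = mutable.PriorityQueue.empty[(Float, V)](
      Ordering.fromLessThan[(Float, V)]((a, b) => a._1 > b._1))
    queue.enqueue((0f, source))

    while (queue.nonEmpty) {
      val (d, u) = queue.dequeue()
      if (!settled(u)) { // ignore stale entries for already settled vertices
        settled += u
        for ((v, cost) <- graph.getOrElse(u, Map.empty[V, Float]) if !settled(v)) {
          val candidate = d + cost
          // relax the arc if v is new or the candidate distance is shorter
          if (dist.get(v).forall(candidate < _)) {
            dist(v) = candidate
            queue.enqueue((candidate, v))
          }
        }
      }
    }
    dist.toMap
  }
}
```

Because the queue has no decrease-key operation, an improved vertex is simply enqueued again and the outdated entry is skipped when it surfaces.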
Minimum distance vertex from set.
Note that, in the assume precondition, instead of writing the usual anonymous function as the actual parameter, as in Q.elements.forall( v => dist.isDefinedAt(v) ), we can pass dist.isDefinedAt without its argument; it is a function expression of the required type V => Boolean.
<<minimumDistVertexDefinition>>=
def minimumDistVertex( Q: HashSet[V]): V = {
  assume( ! Q.isEmpty && Q.elements.forall( dist.isDefinedAt )) ;

  val iterator = Q.elements ;
  val w = iterator.next ; // first element, because Q is not empty

  // calculate and return
  iterator.foldLeft( w) {(u, v) => if (dist( u) <= dist( v)) u else v}
}
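The same idea can be sketched in current Scala (2.13 or later), where elements has been replaced by direct collection methods. This is an illustration under assumptions rather than the page's code: the distance map is passed as an explicit parameter instead of being captured from the enclosing method, and reduceLeft stands in for the iterator.next plus foldLeft pair.

```scala
object MinimumVertexSketch {
  // Returns the vertex in `q` with the smallest tentative distance.
  // Assumes q is non-empty and every vertex in q has an entry in dist.
  def minimumDistVertex[V](q: Set[V], dist: Map[V, Float]): V = {
    // Passing the method dist.isDefinedAt is equivalent to writing
    // v => dist.isDefinedAt(v); the compiler converts it to V => Boolean.
    assume(q.nonEmpty && q.forall(dist.isDefinedAt))
    // reduceLeft uses the first element as the initial value
    q.reduceLeft((u, v) => if (dist(u) <= dist(v)) u else v)
  }
}
```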
Our class WeightedDiGraph implements a directed graph with cost information.
It stores the arcs as a dictionary keyed by origin vertex; each value is in turn a dictionary that maps every destination adjacent to that origin to the cost of the arc.
<<WeightedDiGraph>>=
class WeightedDiGraph[V]() {

  val arcs = new HashMap[ V, HashMap[ V, Float]] ;
  val vertices = new HashSet[ V] ;

  def adjacents( u: V) = arcs( u).keys ;

  def cost( u: V, v: V) = arcs(u)(v) ;

  def addArc( from: V, to: V, kost: Float) = {
    if (arcs.contains( from)) {
      val adjMap = arcs( from) ;
      adjMap += to -> kost
    }
    else {
      val adjMap = new HashMap[ V, Float] ;
      adjMap += to -> kost ;
      arcs += from -> adjMap
    }
    vertices += from ;
    vertices += to
  }

  <<shortestPath>>

  <<dijkstra>>
}
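The dictionary-of-dictionaries layout can be tried out in isolation. The sketch below uses current Scala (2.13 or later) and is not the class above: ArcMapSketch is a hypothetical name, getOrElseUpdate replaces the explicit contains test in addArc, and cost returns an Option so that a missing arc does not throw.

```scala
import scala.collection.mutable

final class ArcMapSketch[V] {
  // origin -> (destination -> cost), the same shape as `arcs` above
  val arcs = mutable.HashMap.empty[V, mutable.HashMap[V, Float]]

  def addArc(from: V, to: V, cost: Float): Unit =
    // getOrElseUpdate creates the inner map the first time `from` is seen
    arcs.getOrElseUpdate(from, mutable.HashMap.empty[V, Float]).update(to, cost)

  // None when there is no arc from u to v
  def cost(u: V, v: V): Option[Float] =
    arcs.get(u).flatMap(_.get(v))
}
```

Since the arcs are directed, adding Barcelona to Narbonne does not create Narbonne to Barcelona; the test module adds each road in both directions for that reason.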
Test module
We load the XML test data file into a variable of type scala.xml.Node, then apply the XPath-like operator \ , which selects the child nodes with the specified element tag. A similar operator \\ selects descendants with the specified tag, as in XPath.
<<ShortestPathTest.scala>>=
package test; // folder: test

import scala.collection.mutable.{HashSet, HashMap} ;
import scala.xml.XML ;

object ShortestPathTest {

  class MyException( msg: String) extends java.lang.RuntimeException( msg) ;

  <<WeightedDiGraph>>

  def main( args: Array[String]) = {
    val roadMap = new WeightedDiGraph[ String] ;
    val test_data = XML.loadFile( "test_data.xml") ;

    for( val arc <- test_data \ "graph" \ "arc") {
      // attribute returns an Option, so unwrap it before converting to String
      val from = arc.attribute("from").get.toString ;
      val to = arc.attribute("to").get.toString ;
      val cost = java.lang.Float.parseFloat( arc.attribute("cost").get.toString) ;
      // roads are two-way: add the arc in both directions
      roadMap.addArc( from, to, cost) ;
      roadMap.addArc( to, from, cost) ;
    }

    // for each test entry in test_data
    for( val test <- test_data \ "sources" \ "test") {
      val source = test.attribute("source").get.toString ;
      val dest = test.attribute("dest").get.toString ;
      Console.println("from " + source + "\n") ;
      try {
        val route = roadMap.shortestPath( source, dest) ;
        if (route == Nil) {
          Console.println( "No route to " + dest)
        }
        else for( val Pair( city, distance) <- route) {
          Console.println( city + ": " + distance)
        }
      }
      catch {
        case m: MyException => Console.println( "MyException: "+ m.getMessage)
        case e: Throwable => Console.println( e.getMessage) ;
      }
      Console.println( "--") ;
    }
  }
}
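The difference between \ and \\ is easy to check on a small document. The sketch below is written for current Scala and assumes the separate scala-xml module is on the classpath (it has not shipped with the standard library since Scala 2.11); the object name XmlSelectSketch and the inline document are made up for the example, and the "@cost" attribute selector is used in place of the attribute method called above.

```scala
import scala.xml.XML

object XmlSelectSketch {
  // A cut-down document with the same shape as test_data.xml
  val doc = XML.loadString(
    """<test_data>
      |  <graph>
      |    <arc from="Geneve" to="Lausanne" cost="64" />
      |  </graph>
      |</test_data>""".stripMargin)

  // \ selects direct children only: arc is a child of graph, not of the root
  val viaChildren = doc \ "graph" \ "arc"
  val directArcs = doc \ "arc" // empty, because arc is not a child of test_data

  // \\ selects matching descendants at any depth
  val viaDescendants = doc \\ "arc"

  // "@name" selects an attribute; text gives its value as a String
  val firstCost = (viaChildren.head \ "@cost").text.toFloat
}
```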
Test data
Approx. roadmap distances.
Barcelona is my city, and Lausanne (Switzerland) hosts EPFL, the polytechnic university where the Scala programming language team develops the language.
<<test_data.xml>>=
<?xml version="1.0" encoding="ISO-8859-1"?>
<test_data>
  <graph>
    <arc from="Barcelona" to="Narbonne" cost="250" />
    <arc from="Narbonne" to="Marseille" cost="260" />
    <arc from="Narbonne" to="Toulouse" cost="150" />
    <arc from="Narbonne" to="Geneve" cost="550" />
    <arc from="Marseille" to="Geneve" cost="470" />
    <arc from="Toulouse" to="Paris" cost="680" />
    <arc from="Toulouse" to="Geneve" cost="700" />
    <arc from="Geneve" to="Paris" cost="540" />
    <arc from="Geneve" to="Lausanne" cost="64" />
    <arc from="Lausanne" to="Paris" cost="536" />
  </graph>
  <sources>
    <test source="Barcelona" dest="Lausanne"/>
    <test source="Lausanne" dest="Barcelona"/>
  </sources>
</test_data>
Testing
Compile and run
scalac test/ShortestPathTest.scala
scala test.ShortestPathTest
will produce the output
from Barcelona

Narbonne: 250.0
Geneve: 800.0
Lausanne: 864.0
--
from Lausanne

Geneve: 64.0
Narbonne: 614.0
Barcelona: 864.0
--