Dijkstra's algorithm (Scala)

From LiteratePrograms

Here is an example of Dijkstra's shortest path algorithm in Scala.

Shortest path

This procedural implementation uses the mutable HashSet and HashMap collections from scala.collection.mutable, which are imported in the test module.

The custom WeightedDiGraph class takes a vertex type parameter V, which the test module instantiates as String.

Arc costs are of type Float, but Int or Double would work equally well.

<<shortestPath>>=
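/**
*  @param start  origin vertex; must be the origin of at least one arc
*  @param end    destination vertex; must be a vertex of the graph
*  @return  route as list of pairs (node, shortest_distance), not including start;
*           Nil if end is not reached (unconnected graph) or if end equals start
*/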
def shortestPath( start: V, end: V): List[Pair[V, Float]] = {
  if (!arcs.contains( start)) 
    throw new MyException( "shortestPath: "+start+" not in graph edge origins") ;
  if (!vertices.contains( end)) 
    throw new MyException( "shortestPath: "+end+" not in graph vertices") ;
  // get shortest-distances, predecessors, end result
  val Triple( dist, pred, endReached) = dijkstra( start, end) ;
  // build path from end based on predecessors
  var path: List[Pair[V, Float]] = Nil ;
  if (endReached) {
    var v = end ;
    while ( v != start) {
      path = Pair(v, dist(v)) :: path ;
      // iterate on predecessor      
      v = pred( v)
    }
  }
  // return path
  path
}

Dijkstra's algorithm

<<dijkstra>>=
/**
*  @param source
*  @param end  (end vertex or null)
*  @return  Triple( distances, predecessors, endReached)
*/
def dijkstra( source: V, end: V) = {
  assume( arcs.contains( source), "source not in arcs origins")
  if (end != null) assume( vertices.contains( end), "end not in graph")
  // initialize
  val dist = new HashMap[V, Float] ;  // distances
  val Q = new HashSet[V] ;            // unsettled frontier: a plain set, scanned for its minimum
  val Settled = new HashSet[V] ; // settled vertices
  val pred = new HashMap[V, V] ;    // predecessors
  minimumDistVertexDefinition
  // start with source vertex
  dist += source -> 0F ;
  Q += source ;
  var endReached = false ;
  while (! Q.isEmpty && ! endReached) {
    // extract minimumDistVertex from Q, add to Settled ones
    val u = minimumDistVertex( Q) ;
    Q -= u ;
    Settled += u ;
    if (end != null) endReached = (u == end) ;
    // update neighbours distances
    // and add updated ones to Q
    if (! endReached)
      for( val v <- adjacents( u); ! Settled.contains( v)) {
        val vNewDist = dist( u) + cost(u, v) ;
        if ( ! dist.isDefinedAt( v) || vNewDist < dist(v)) {
          dist += v -> vNewDist ;
          pred += v -> u ;
          Q += v ;
        }
      }
  }
  // return distances, predecessors, endReached
  Triple( dist, pred, endReached)
}
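
The set Q above is scanned linearly for its minimum on every iteration. For comparison, here is a minimal sketch of the same relaxation loop built on scala.collection.mutable.PriorityQueue. It is not part of the original program: dijkstraPQ is a hypothetical standalone function, it takes the graph as an immutable Map with the same shape as arcs, it has no early exit at an end vertex, and it assumes a recent Scala version (2.13 or later) rather than the dialect of the chunks above.

```scala
import scala.collection.mutable

// Hypothetical standalone variant, not part of WeightedDiGraph.
// graph: origin -> (destination -> cost), the same shape as `arcs`.
def dijkstraPQ[V](graph: Map[V, Map[V, Float]], source: V): (Map[V, Float], Map[V, V]) = {
  val dist = mutable.HashMap[V, Float](source -> 0f)
  val pred = mutable.HashMap.empty[V, V]
  val settled = mutable.HashSet.empty[V]
  // PriorityQueue dequeues its largest element, so the ordering is reversed:
  // an entry counts as "smaller" when its distance is greater.
  val byShortest = Ordering.fromLessThan[(Float, V)]((a, b) => a._1 > b._1)
  val queue = mutable.PriorityQueue.empty[(Float, V)](byShortest)
  queue.enqueue((0f, source))
  while (queue.nonEmpty) {
    val (d, u) = queue.dequeue()
    // A vertex may be enqueued several times; skip entries that are stale.
    if (!settled.contains(u)) {
      settled += u
      for ((v, w) <- graph.getOrElse(u, Map.empty[V, Float]) if !settled.contains(v)) {
        val candidate = d + w
        // Relax the arc when v is unseen or the new distance is shorter.
        if (dist.get(v).forall(candidate < _)) {
          dist(v) = candidate
          pred(v) = u
          queue.enqueue((candidate, v))
        }
      }
    }
  }
  (dist.toMap, pred.toMap)
}
```

Instead of updating a vertex in place, the sketch enqueues a new entry whenever a distance improves and discards outdated entries when they are dequeued. This keeps the code short at the price of a somewhat larger queue.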

Minimum distance vertex from set.

The assume precondition checks that every vertex in Q already has a distance. Instead of passing the usual anonymous function, as in Q.elements.forall( v => dist.isDefinedAt( v)), we can pass dist.isDefinedAt without its parameter: the method name alone is a function expression of the required type V => Boolean. The chunk is included inside dijkstra, so dist is in scope.

<<minimumDistVertexDefinition>>= 
def minimumDistVertex( Q: HashSet[V]): V = {
  assume( ! Q.isEmpty && Q.elements.forall( dist.isDefinedAt )) ;
  val iterator = Q.elements ;
  val w = iterator.next ;          // first element, because Q is not empty
  // calculate and return
  iterator.foldLeft( w) {(u, v) => if (dist( u) <= dist( v)) u else v}
}
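
The two ways of writing the precondition can be compared side by side. The following is a small sketch, assuming a recent Scala version (2.13 or later), where forall is called on the set directly instead of on Q.elements. The sample values and the names queue, explicitForm and shortForm are illustrative only.

```scala
import scala.collection.mutable.{HashMap, HashSet}

// Sample data standing in for dist and Q inside dijkstra
val dist = HashMap("Barcelona" -> 0f, "Narbonne" -> 250f)
val queue = HashSet("Barcelona", "Narbonne")

// Explicit anonymous function as the predicate
val explicitForm = queue.forall(v => dist.isDefinedAt(v))

// Method name alone: the compiler turns it into a String => Boolean function
val shortForm = queue.forall(dist.isDefinedAt)

// A vertex without a recorded distance makes the precondition fail
queue += "Paris"
val afterParis = queue.forall(dist.isDefinedAt)
```

Both forms evaluate to true for the initial data, and afterParis is false because "Paris" has no entry in dist.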

Our class WeightedDiGraph implements a directed graph with cost information.

Arcs are stored as a dictionary keyed by origin vertex. Each value is itself a dictionary that maps a destination vertex to the cost of the arc.

<<WeightedDiGraph>>=
class WeightedDiGraph[V]() { 
  val arcs = new HashMap[ V, HashMap[ V, Float]] ;
  val vertices = new HashSet[ V] ;
  def adjacents( u: V) = arcs.getOrElse( u, new HashMap[ V, Float]).keys ; // no outgoing arcs: empty
  def cost( u: V, v: V) = arcs(u)(v) ;
  def addArc( from: V, to: V, kost: Float) = {
    if (arcs.contains( from)) {
        val adjMap = arcs( from) ;
        adjMap += to -> kost
    }
    else {
      val adjMap = new HashMap[ V, Float] ;
      adjMap += to -> kost ;
      arcs += from -> adjMap
    }
    vertices += from ;
    vertices += to 
  }
  shortestPath
  dijkstra
}
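
To make the nested-dictionary layout concrete, here is a small standalone sketch with String vertices. It assumes a recent Scala version (2.13 or later); the addArc function below is a hypothetical helper that only mirrors the effect of the class method on arcs, and it does not maintain the vertices set.

```scala
import scala.collection.mutable.HashMap

// origin -> (destination -> cost), mirroring the `arcs` field
val arcs = HashMap.empty[String, HashMap[String, Float]]

// Hypothetical helper: create the inner dictionary on first use, then store the cost
def addArc(from: String, to: String, cost: Float): Unit =
  arcs.getOrElseUpdate(from, HashMap.empty[String, Float]).update(to, cost)

addArc("Barcelona", "Narbonne", 250f)
addArc("Narbonne", "Toulouse", 150f)
addArc("Narbonne", "Geneve", 550f)

val neighbours = arcs("Narbonne").keySet           // Toulouse and Geneve, in unspecified order
val narbonneToGeneve = arcs("Narbonne")("Geneve")  // 550.0
val toulouseIsOrigin = arcs.contains("Toulouse")   // false: Toulouse is only a destination here
```

A vertex that appears only as a destination has no entry in arcs. This is why shortestPath looks up start in arcs but end in vertices.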

Test module

We load the XML test data file into a value of type scala.xml.Node and then apply the XPath-like operator \, which selects the child nodes with the specified element tag. A similar operator, \\, selects descendants at any depth with the specified tag, as // does in XPath.

<<ShortestPathTest.scala>>=
package test;                  // folder: test 
import scala.collection.mutable.{HashSet, HashMap} ;
import scala.xml.XML ;
object ShortestPathTest {
  class MyException( msg: String) extends java.lang.RuntimeException( msg) ;
  WeightedDiGraph
  def main( args: Array[String]) = {
      val roadMap = new WeightedDiGraph[ String] ;
      val test_data = XML.loadFile( "test_data.xml") ;
      for( val arc <- test_data \ "graph" \ "arc") {
        val from = arc.attribute("from").get.toString ;
        val to = arc.attribute("to").get.toString ;
        val cost = java.lang.Float.parseFloat( arc.attribute("cost").get.toString) ;
        roadMap.addArc( from, to, cost) ;
        roadMap.addArc( to, from, cost) ;
      }
      // for each test entry in test_data
      for( val test <- test_data \ "sources" \ "test") {
        val source = test.attribute("source").get.toString ;
        val dest = test.attribute("dest").get.toString ;
        Console.println("from " + source) ;
        try {
          val route = roadMap.shortestPath( source, dest) ;
          if (route == Nil) {
            Console.println( "No route to " + dest)
          }
          else for( val Pair( city, distance) <- route) {
            Console.println( city + ": " + distance)
          }
        }
        catch {
          case m: MyException => Console.println( "MyException: "+ m.getMessage)
          case e: Throwable => Console.println( e.getMessage) ;
        }
        Console.println( "--") ;  
      }
  }
}

Test data

Example of roadmap distance

Approx. roadmap distances.

Barcelona is my city; Lausanne (Switzerland) hosts EPFL, the polytechnic university where the Scala programming language team develops the language.

<<test_data.xml>>=
<?xml version="1.0" encoding="ISO-8859-1"?>
<test_data>
	<graph>
		<arc from="Barcelona" to="Narbonne" cost="250" />
		<arc from="Narbonne" to="Marseille" cost="260" />
		<arc from="Narbonne" to="Toulouse" cost="150" />
		<arc from="Narbonne" to="Geneve" cost="550" />
		<arc from="Marseille" to="Geneve" cost="470" />
		<arc from="Toulouse" to="Paris" cost="680" />
		<arc from="Toulouse" to="Geneve" cost="700" />
		<arc from="Geneve" to="Paris" cost="540" />
		<arc from="Geneve" to="Lausanne" cost="64" />
		<arc from="Lausanne" to="Paris" cost="536" />
	</graph>
	<sources>
		<test source="Barcelona" dest="Lausanne"/>
		<test source="Lausanne" dest="Barcelona"/>
	</sources>
</test_data>

Testing

The shortest path between Barcelona and Lausanne

Compiling and running, with test_data.xml in the current directory,

scalac test/ShortestPathTest.scala
scala test.ShortestPathTest

produces the following output:

from Barcelona
Narbonne: 250.0
Geneve: 800.0
Lausanne: 864.0
--
from Lausanne
Geneve: 64.0
Narbonne: 614.0
Barcelona: 864.0
--
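
The reported total of 864 can be checked by hand against the test data. The sketch below sums the arc costs along three candidate routes from Barcelona to Lausanne. It assumes a recent Scala version (2.13 or later), copies only the arcs that do not touch Paris, and uses hypothetical names (arcCost, routeCost) that are not part of the program above.

```scala
// Arc costs copied from test_data.xml; each arc is usable in both directions.
val arcCost: Map[(String, String), Float] = Map(
  ("Barcelona", "Narbonne") -> 250f,
  ("Narbonne", "Marseille") -> 260f,
  ("Narbonne", "Toulouse") -> 150f,
  ("Narbonne", "Geneve") -> 550f,
  ("Marseille", "Geneve") -> 470f,
  ("Toulouse", "Geneve") -> 700f,
  ("Geneve", "Lausanne") -> 64f
)

// Hypothetical helper: total cost of a route given as a list of cities.
// Each consecutive pair is looked up as written, then in the reverse direction.
def routeCost(route: List[String]): Float =
  route.zip(route.tail).map { case (a, b) =>
    arcCost.getOrElse((a, b), arcCost((b, a)))
  }.sum

val direct = routeCost(List("Barcelona", "Narbonne", "Geneve", "Lausanne"))
val viaMarseille = routeCost(List("Barcelona", "Narbonne", "Marseille", "Geneve", "Lausanne"))
val viaToulouse = routeCost(List("Barcelona", "Narbonne", "Toulouse", "Geneve", "Lausanne"))

println(s"$direct $viaMarseille $viaToulouse")  // prints "864.0 1044.0 1164.0"
```

The route through Narbonne and Geneve is the cheapest of the three, and its running totals (250, 800, 864) match the distances printed for each city. Since every arc is added in both directions, the return trip from Lausanne has the same total.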