Machine Learning with Clojure and Spark using Flambo

clojure spark ml

A tutorial on training a Logistic Regression classifier on Spark using Clojure.

In this short tutorial I’m going to show how to train a logistic regression classifier in a scalable manner with Apache Spark and Clojure using Flambo.


The goal of the tutorial is to help you familiarize yourself with Flambo – a Clojure DSL for Apache Spark. Even though Flambo is far from being complete, it already does a decent job of wrapping basic Spark APIs into idiomatic Clojure.

During the course of the tutorial, we are going to train a classifier capable of predicting whether a wine would taste good given certain objective chemical characteristics.

Step 1. Create a New Project

Run these commands:

$ lein new app t01spark
$ cd t01spark

Here, t01spark is the name of the project. You can give it any name you’d like. Don’t forget to change the current directory to the project you’ve just created.

Step 2. Update project.clj

Open project.clj in a text editor and update the dependency section so it looks like this:

    [[org.clojure/clojure "1.6.0"]
     [yieldbot/flambo "0.6.0"]
     [org.apache.spark/spark-mllib_2.10 "1.3.0"]]

Please note that although listing Spark jars in this manner is perfectly fine for exploratory projects, it is not suitable for production use. For that you will need to list them as “provided” dependencies in the profiles section, but let’s keep things simple for now.
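For reference, a production-style project.clj would keep the Spark artifact out of the main dependency list and declare it under a `:provided` profile instead. A rough sketch of what that section might look like (layout follows the usual Leiningen profile conventions):

```clojure
:dependencies [[org.clojure/clojure "1.6.0"]
               [yieldbot/flambo "0.6.0"]]
;; Spark jars come from the cluster at runtime, so they are
;; only needed locally for compilation and the REPL:
:profiles {:provided
           {:dependencies
            [[org.apache.spark/spark-mllib_2.10 "1.3.0"]]}}
```

This way the uberjar you submit to a cluster does not bundle a second copy of Spark.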

Make sure that AOT is enabled, otherwise you will see strange ClassNotFound errors. Add this to the project file:

:aot :all

It could also make sense to give Spark some extra memory:

:jvm-opts ^:replace ["-server" "-Xmx2g"]

Step 3. Download the Dataset

In this tutorial we are going to use the Wine Quality Dataset. Download it and save it next to the project.clj file:

$ wget https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv

Step 4. Start the REPL

The simplest way of running the Clojure REPL is Leiningen’s `repl` command:

$ lein repl
Clojure 1.6.0
Java HotSpot(TM) 64-Bit Server VM 1.8.0_xxx

Of course, nothing prevents you from running the REPL in Emacs with CIDER, IntelliJ IDEA, or any other Clojure-aware IDE.

Step 5. Require Modules and Import Classes

user=> (require '[flambo.api :as f]
                '[flambo.conf :as cf]
                '[flambo.tuple :as ft]
                '[clojure.string :as s])

user=> (import '[org.apache.spark.mllib.linalg Vectors]
               '[org.apache.spark.mllib.regression LabeledPoint]
               '[org.apache.spark.mllib.classification LogisticRegressionWithLBFGS]
               '[org.apache.spark.mllib.evaluation BinaryClassificationMetrics])

Step 6. Create a Spark Context

user=> (def spark
         (let [cfg (-> (cf/spark-conf)
                       (cf/master "local[2]")
                       (cf/app-name "t01spark")
                       (cf/set "spark.akka.timeout" "300"))]
           (f/spark-context cfg)))

We’ve just created a Spark context bound to a local, in-process Spark server. You should see lots of INFO log messages in the terminal. That’s normal. Again, creating a Spark context like this will work for tutorial purposes, although in real life you’d probably want to wrap this expression into a memoizing function and call it whenever you need a context.
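A minimal sketch of the memoizing variant mentioned above (the name `get-context` is made up here; `memoize` ensures the context is created only once, no matter how many times the function is called):

```clojure
(require '[flambo.api :as f]
         '[flambo.conf :as cf])

;; The first call builds the context; every later call
;; returns the same cached instance.
(def get-context
  (memoize
   (fn []
     (f/spark-context
      (-> (cf/spark-conf)
          (cf/master "local[2]")
          (cf/app-name "t01spark"))))))
```

Then any piece of code that needs a context simply calls `(get-context)`.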

Step 7. Load and Parse Data

The data is stored in a CSV file with a header. We don’t need that header. To get rid of it, let’s enumerate rows and retain only those with indexes greater than zero. Then we split each row by the semicolon character and convert each element to float:

user=> (def data
         ;; Read lines from file
         (-> (f/text-file spark "winequality-red.csv")
             ;; Enumerate lines.
             ;; This function is missing from Flambo,
             ;; so we call the method directly
             (.zipWithIndex)
             ;; This is here purely for convenience:
             ;; it transforms Spark tuples into Clojure vectors
             (f/map f/untuple)
             ;; Get rid of the header
             (f/filter (f/fn [[line idx]] (< 0 idx)))
             ;; Split lines and transform values
             (f/map (f/fn [[line _]]
                      (->> (s/split line #";")
                           (map #(Float/parseFloat %)))))))

Let’s verify what’s in the RDD:

user=> (f/take data 3)
[(7.4 0.7 0.0 1.9 0.076 11.0 34.0 0.9978 3.51 0.56 9.4 5.0)
 (7.8 0.88 0.0 2.6 0.098 25.0 67.0 0.9968 3.2 0.68 9.8 5.0)
 (7.8 0.76 0.04 2.3 0.092 15.0 54.0 0.997 3.26 0.65 9.8 5.0)]

Looks legit.
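To see what the enumerate-filter-parse pipeline does without involving Spark at all, here is the same logic in plain Clojure over a few made-up lines (note that `map-indexed` yields [index line] pairs, the reverse order of `f/untuple`'s [line index]):

```clojure
(require '[clojure.string :as s])

(def lines ["fixed acidity;volatile acidity;quality"
            "7.4;0.7;5"
            "7.8;0.88;5"])

(def parsed
  (->> (map-indexed vector lines)         ; enumerate: [idx line]
       (filter (fn [[idx _]] (< 0 idx)))  ; drop the header row
       (map (fn [[_ line]]
              (->> (s/split line #";")
                   (mapv #(Float/parseFloat %)))))))
;; parsed => two vectors of three floats each
```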

Step 8. Transform the Data

The subjective wine quality information is contained in the Quality variable. It takes values in the [0..10] range. Let’s transform that into a binary variable by splitting it at the median: wines with Quality below 6 will be considered “not good”, and those with 6 and above “good”.

I explored this dataset in R and found that the most interesting variables are Citric Acid, Total Sulfur Dioxide and Alcohol. I encourage you to experiment with adding other variables to the model. Also, using logarithms of those variables instead of raw values might be a good idea. Please refer to the Wine Quality Dataset documentation for a full variable list.
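If you want to try the logarithm idea, a tiny helper like this hypothetical `log1p-features` could be applied to the predictor values before packing them into a vector (log(1 + x) rather than log(x), so that a zero-valued Citric Acid stays finite):

```clojure
(defn log1p-features
  "Apply log(1 + x) to each predictor value."
  [xs]
  (mapv #(Math/log (inc (double %))) xs))

;; A zero predictor maps to 0.0; large values are compressed:
(log1p-features [0.0 34.0 9.4])
```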

user=> (def dataset
         (f/map data
                (f/fn [[_ _ citric-acid _ _ _
                        total-sulfur-dioxide _ _ _
                        alcohol quality]]
                  ;; Label: 0.0 for a "good" wine (quality 6+),
                  ;; 1.0 otherwise
                  (let [good (if (<= 6 quality) 0.0 1.0)
                        ;; these will be our predictors
                        pred (double-array [citric-acid
                                            total-sulfur-dioxide
                                            alcohol])]
                    ;; Spark requires samples to be packed into LabeledPoints
                    (LabeledPoint. good (Vectors/dense pred))))))

user=> (f/take dataset 3)
[#<LabeledPoint (1.0,[0.0,34.0,9.399999618530273])>
 #<LabeledPoint (1.0,[0.0,67.0,9.800000190734863])>
 #<LabeledPoint (1.0,[0.03999999910593033,54.0,9.800000190734863])>]

There is no order guarantee in derived RDDs, so you might get a different result.

Step 9. Prepare Training and Validation Datasets

user=> (f/cache dataset) ; Temporarily cache the source dataset
                         ; BTW, caching is a side effect

user=> (def training
         (-> (f/sample dataset false 0.8 1234)
             (f/cache)))

user=> (def validation
         (-> (.subtract dataset training)
             (f/cache)))

user=> (map f/count [training validation]) ; Check the counts
(1291 235)

user=> (.unpersist dataset) ; no need to cache it anymore

Caching is crucial for MLlib performance. In fact, Spark MLlib algorithms will complain if you feed them uncached datasets.

Step 10. Train a Classifier

MLlib-related parts are completely missing from Flambo, but that’s hopefully coming soon. For now, let’s use the Java API directly.

user=> (def classifier
         (doto (LogisticRegressionWithLBFGS.)
           ;; Otherwise we'll need to provide it
           (.setIntercept true)))

user=> (def model
         (doto (.run classifier (.rdd training))
           ;; We need the "raw" probability predictions
           (.clearThreshold)))
user=> [(.intercept model) (.weights model)]
[...
 #<DenseVector [-1.6766504448212323,0.011619041367225583,-0.9683045663615859]>]

Step 11. Assess Predictive Power

First, let’s create a function to compute the area under the precision-recall curve and the area under the receiver operating characteristic curve. These are the most important indicators of the predictive power of a trained classification model.

user=> (defn metrics [ds model]
         ;; Here we construct an RDD containing [prediction, label]
         ;; tuples and compute classification metrics.
         (let [pl (f/map ds (f/fn [point]
                              (let [y (.label point)
                                    x (.features point)]
                                (ft/tuple (.predict model x) y))))
               metrics (BinaryClassificationMetrics. (.rdd pl))]
           [(.areaUnderROC metrics)
            (.areaUnderPR metrics)]))
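As a quick intuition for what `areaUnderROC` reports, here is a pure-Clojure trapezoid-rule sketch over a hypothetical ROC curve given as [FPR TPR] points (nothing here touches Spark; the `auc` helper is made up for illustration):

```clojure
(defn auc
  "Trapezoid-rule area under a curve given as [x y] points sorted by x."
  [points]
  (->> (partition 2 1 points)
       (map (fn [[[x1 y1] [x2 y2]]]
              (* (- x2 x1) (/ (+ y1 y2) 2))))
       (reduce +)))

;; A perfect classifier gives 1.0, a random one 0.5:
(auc [[0.0 0.0] [0.5 0.8] [1.0 1.0]]) ;; ≈ 0.65
```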

Obtain metrics for the training dataset:

user=> (metrics training model)
[0.7800174890996763 0.7471259498290513]

And then for the validation dataset:

user=> (metrics validation model)
[0.7785138248847928 0.7160113864756078]

We’re overfitting slightly. OK, let’s turn L2 regularization on and rebuild the model:

user=> (doto (.optimizer classifier)
         (.setRegParam 0.0001))

user=> (def model
         (doto (.run classifier (.rdd training))
           ;; Again, we need the "raw" probability predictions
           (.clearThreshold)))

user=> (metrics training model)
[0.7794660966515655 0.748073583460006]

user=> (metrics validation model)
[0.7807459677419355 0.7200550175610565]

Looks good? I’m sure you can do better.

Step 12. Build a Predictor Function

As a final step, let’s define a function that we could use for predicting wine quality:

user=> (defn is-good? [model citric-acid
                       total-sulfur-dioxide alcohol]
         (let [point (-> (double-array [citric-acid
                                        total-sulfur-dioxide
                                        alcohol])
                         (Vectors/dense))
               prob (.predict model point)]
           ;; Class 1.0 encodes a "not good" wine,
           ;; so a low probability means good
           (< prob 0.5)))

user=> (is-good? model 0.0 34.0 9.399999618530273)


We have built a simple logistic regression classifier in Clojure on Apache Spark using Flambo. Some parts of the Flambo API are still missing, but it’s definitely usable. It was not terribly difficult to get it working and I hope you had fun.