-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathSparkWithPMML.scala
More file actions
62 lines (49 loc) · 2.22 KB
/
SparkWithPMML.scala
File metadata and controls
62 lines (49 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import java.io.File
import java.io.FileWriter
import java.util.LinkedHashMap

import javax.xml.transform.stream.StreamResult

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession
import org.dmg.pmml.FieldName
import org.jpmml.evaluator.Evaluator
import org.jpmml.evaluator.EvaluatorUtil
import org.jpmml.evaluator.FieldValue
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder
/** End-to-end example: train a Spark ML logistic-regression pipeline,
  * export it as a PMML document, then reload that PMML with the JPMML
  * evaluator (independently of Spark) and score a single held-out row.
  *
  * Expects `data/simple-ml.json` to exist; writes the PMML to `data/model`.
  */
object SparkWithPMML {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("ScalaSparkML").setMaster("local")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("data/simple-ml.json")
      // PMMLBuilder needs the DataFrame schema to map Spark columns to PMML fields.
      val schema = df.schema
      val Array(train, test) = df.randomSplit(Array(0.7, 0.3))

      // RFormula: predict `lab` from every column plus two color interactions.
      val supervised = new RFormula()
        .setFormula("lab ~ . + color: value1 + color: value2")
        .setLabelCol("label1")
        .setFeaturesCol("features2")
      val lr = new LogisticRegression().setLabelCol("label1").setFeaturesCol("features2")
      val pipeline = new Pipeline().setStages(Array(supervised, lr))
      val model = pipeline.fit(train)

      // Export the fitted pipeline as PMML XML on disk.
      val pmml = new PMMLBuilder(schema, model).build()
      JAXBUtil.marshalPMML(pmml, new StreamResult(new File("data/model")))

      // Reload the PMML without Spark and verify the model's self-checks.
      val evaluator = new LoadingModelEvaluatorBuilder().load(new File("data/model")).build()
      evaluator.verify()

      // Score one test row: for each PMML input field, look up the matching
      // column value in the Row and let JPMML prepare (type-convert) it.
      // getInputFields() returns a java.util.List; convert explicitly with
      // .asScala instead of relying on deprecated implicit JavaConversions.
      val row = test.first()
      val arguments = new LinkedHashMap[FieldName, FieldValue]
      for (inputField <- evaluator.getInputFields().asScala) {
        val fieldName = inputField.getFieldName()
        val value = row.get(row.fieldIndex(fieldName.getValue()))
        arguments.put(fieldName, inputField.prepare(value))
      }

      val result = evaluator.evaluate(arguments)
      // decodeAll converts JPMML's internal field values back to plain Java objects.
      val decodedResult = EvaluatorUtil.decodeAll(result)
      println(decodedResult)
    } finally {
      // Always release the local Spark session, even if training/scoring fails.
      spark.stop()
    }
  }
}