Perfekter Mix: Kotlin und Python kombinieren für Machine Learning

Seite 2: Modelle erstellen und trainieren mit KotlinDL

Inhaltsverzeichnis

Als zweiten Aspekt nach dem Erkunden von Daten geht es nun um das Erstellen von Modellarchitekturen und das Training von Modellen mit Kotlin. Die Voraussetzung dafür ist KotlinDL in der aktuellen Version 0.4.0. Die Codebasis ist Open Source und auf GitHub verfügbar.

KotlinDL beschreibt sich selbst, als High-Level-Deep-Learning-API, ist in Kotlin geschrieben und stark von der Python-Library Keras inspiriert. Unter der Haube verwendet es die TensorFlow Java-API und die ONNX-Runtime-API für Java. Dadurch ist auch ein einfaches Training auf Grafikkarten möglich, allerdings derzeit nur auf NVIDIA-Modellen.

KotlinDL bietet einfache APIs für das Training von Deep-Learning-Modellen von Grund auf. Auch der Import bereits trainierter KotlinDL-, Keras- und ONNX-Modelle für die Inferenz ist möglich. Importierte und vortrainierte Modelle lassen sich mit Transfer-Learning an eigene Aufgaben anpassen.

Das folgende Beispiel trainiert mit KotlinDL ein einfaches Modell, das Bilder in Kategorien einteilt. Der fertige Code dazu ist im folgenden Listing dargestellt:

import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import java.io.File

// Laden der Trainings- und Testdaten
val (train, test) = fashionMnist()

// Model Architektur
val net = Sequential.of(
    Input(28L,28L,1L),
    Flatten(),
    Dense(300),
    Dense(100),
    Dense(10)
)

with(net){ // this: net
    // Einstellungen fuer Training festlegen
    compile(
        optimizer = Adam(),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY
    )
    
    printSummary()
    
    // Training des Modells
    fit(
        dataset = train,
        epochs = 5,
        batchSize = 100
    )
    
    // Evaluierung des Modells mit den Testdatensatz
    val accuracy = evaluate(dataset = test, batchSize = 100)
        .metrics[Metrics.ACCURACY]
    
    println("Accuracy: $accuracy")
   
}

Listing 4: KotlinDL: Ein einfaches Klassifizierungsmodell

KotlinDL liefert auch einige Datensätze mit, so auch den Fashion MNIST-Datensatz von Zalando. Er enthält 70.000 Bilder von Kleidungsstücken aus zehn Kategorien. Über die Hilfsfunktion val (train, test) = fashionMnist() lässt er sich importieren und direkt in 60.000 Trainingsbeispiele sowie 10.000 Testbeispiele aufteilen.

Als Nächstes gilt es, die Architektur des neuronalen Netzes festzulegen. Hier bietet sich sequenzielle Architektur an, bei der die Daten alle Schichten der Reihe nach durchlaufen. Die Input-Daten haben eine Bildgröße von 28 mal 28 Pixeln: Das erfordert einen Input-Layer entsprechender Größe, festzulegen mit Input(28L, 28L, 1L).

Die zweidimensionalen Bilddaten lassen sich mit einem Flatten()-Layer in einen eindimensionalen Vektor der Länge 784 umwandeln, damit die folgenden Layer die Daten verarbeiten können.

Es folgen zwei vollverknüpfte Dense-Layer, in denen die eigentliche Lernleistung des Modells stattfindet. Am Ende steht noch ein Output-Layer (ebenfalls Dense), der zehn Outputs hat, was der Zahl der zu unterscheidenden Kategorien entspricht. Damit ist die einfache Modellarchitektur fertig.

Im nächsten Schritt gilt es, die Einstellungen zum Training des Modells festzulegen. Als Optimierungsalgorithmus steht der Adam-Optimizer bereit und die loss function ist auf SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS festzulegen. Die Metrik, die der Algorithmus optimieren soll, lautet metric = Metrics.ACCURACY.

Insgesamt sind das gewöhnliche Werte, mit denen in der ersten Iteration eines Klassifizierungsmodells nichts schiefgehen kann. Beim Training eines echten Machine-Learning-Modells wären jedoch genau auf die Aufgabe abgestimmte Parameter auszuwählen.

Das eigentliche Training findet mit dem Aufruf der fit-Funktion statt. Sie nimmt den Trainingsdatensatz zusammen mit der Anzahl der Iterationen über den Trainingsdatensatz (epochs) entgegen. Mit batchSize lässt sich festlegen, wie viele Beispiele auf einmal verwendbar sind, um die Parameter des Modells zu aktualisieren. Nach dem Training des Modells können Entwicklerinnen und Entwickler mit dem Testdatensatz und der evaluate-Funktion die Qualität der Modellvorhersage bestimmen.

Wie das Beispiel verdeutlich, lässt sich ein Machine-Learning-Modell bereits mit wenigen Zeilen Kotlin-Code trainieren und evaluieren. Dabei ist eine Vorhersagegenauigkeit von mehr als 85 Prozent erreichbar. In einem realen Projekt wäre das jedoch nur der erste Schritt in einer ganzen Reihe von Experimenten, um die optimale Vorhersagequalität eines Modells zu bestimmen.

Das trainierte Modell muss man in einen Webservice oder eine App verpacken, damit Nutzer damit interagieren können und es auf der Basis von Daten Vorhersagen treffen kann. Mit KotlinDL und dem Framework Ktor lässt sich in wenigen Zeilen ein ML-Modell-Server erstellen.

Ktor ist ein Kotlin-Framework auf Basis von Kotlin-Coroutines, das es ermöglicht, schlanke Webapplikationen und APIs zu bauen. Im Beispielprojekt (siehe folgendes Listing) geht es darum, ein Objekterkennungsmodell über eine REST API zur Verfügung zu stellen:

package modelserver

import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.serialization.jackson.*
import io.ktor.server.application.*
import io.ktor.server.routing.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import java.io.File
import java.util.*

fun main() {
  val modelHub = 
    ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
  val model =  
    ONNXModels.ObjectDetection.SSD.pretrainedModel(modelHub) //(1)

  embeddedServer(Netty, port = 8089) { // (2)
    install(ContentNegotiation) {
      jackson()
    }
    routing {
      post("/detect") { // (3)
        // Bilddatei empfangen und speichern (4)
        val multipartData = call.receiveMultipart()
        lateinit var imageFile: File
        val newFileName = UUID.randomUUID()
        multipartData.forEachPart { part ->
          when (part) {
            is PartData.FileItem -> {
              val fileName = part.originalFileName as String
              val fileBytes = part.streamProvider().readBytes()
              call.application.environment.log.info(
                "Received file: \"$fileName\"")
              imageFile=File("uploads/$newFileName")
              imageFile.writeBytes(fileBytes)
            }
            else -> TODO()
          }
        }
        // Inferenz: Objekte erkennen (5)
        val detectedObjects = 
        model.detectObjects(imageFile = imageFile, topK = 3)
        call.application.environment.log.info(
          "Found: $detectedObjects")
        // Erkannte Objekte zurueckgeben // (6)
        call.respond(detectedObjects)
      }
    }
  }.start(wait = true)
}

Listing 5: Model Serving mit Ktor und KotlinDL

Als Grundlage dient ein bereits fertig trainiertes Modell, das mithilfe von KotlinDL aus dem ONNXModelHub (1) zu laden ist. Einen Ktor-Server zu starten, ist kein Hexenwerk: Alles, was man dafür benötigt, ist folgender Code (2):

fun main() {
    embeddedServer(Netty, port = 8089) {
        //...
    }.start(wait=true)
}

In diesem Server-Grundgerüst lassen sich verschiedene REST-API-Endpunkte im Bereich routing festlegen. Hier gilt es, eine Route mit der post-Funktion zu definieren (3). Die Funktion kann Bilddaten entgegennehmen, sie verarbeiten und zur Inferenz an das Modell übergeben (4). Der Versand der Bilddaten erfolgt über Multipart-Formdata.

Dazu iteriert der Code über die gesamten Multipart-Daten und überprüft, ob das entsprechende Feld ein FileItem enthält. Ist das der Fall, ist der Dateiname auszulesen und für ein späteres Debugging zu loggen.

Die eigentlich Daten sind mit streamProvider().readBytes() auszulesen und in einer neuen Datei zu speichern, die eine eindeutige ID benötigt. Im Beispiel ist dafür UUID.randomUUID() als ID zuzuweisen.

Sind die Bilddaten aus der Anfrage eingelesen, kann das Modell eine Objekterkennung durchführen (5): Mit model.detectObjects(imageFile = imageFile, topK = 3) erhält es die gespeicherte Bilddatei zur Erkennung. Dabei legen die Zuständigen mit topK = 3 fest, dass sie nur an den drei wahrscheinlichsten Ergebnissen interessiert sind. Die erkannten Objekte lassen sich mit call.respond(detectedObjects) zurückgeben (6).

Auf diese Weise entsteht in nur wenigen Zeilen Code mit Ktor und KotlinDL ein einfacher Modellserver. Die Grundlage kann jedes fertig trainierte Keras- oder ONNX-Modell sein.

Für den produktiven Betrieb fehlen noch einige Schritte wie das Erfassen essenzieller Metriken, die Dauer der Inferenz und die Vorhersagequalität. Auch würde das Modell in einem produktionsreifen Modellserver nicht dynamisch vom Modellhub geladen, sondern aus einem statischen Ordner gelesen, in dem es seit dem Build-Prozess liegt.

Zurzeit ist Python die dominierende Sprache im Machine Learning. Die Frage, warum Entwickler und ML-Ingenieurinnen sich mit anderen Programmiersprachen beschäftigen sollten und wofür sich das lohnt, steht im Raum. Dabei ist es hilfreich, die Frage einmal umzudrehen: Warum ist Python so beliebt für Machine Learning?

Zum einen ist Python eine großartige Programmiersprache und hat seit ihrem Erscheinen im Jahr 1991 eine rasante Entwicklung durchgemacht. Vor allem ist rund um Python ein ausgereiftes Ökosystem entstanden – von Bibliotheken für wissenschaftliche Anwendungen über Vektor- und Matrix-Mathematik bis hin zum Data Processing.

Da Machine Learning neben dem tatsächlichen Einsatz für Geschäftsanwendungen zunächst auch ein umfangreiches Forschungsfeld ist, sind viele der Bibliotheken in einer Sprache geschrieben, die den Entwicklern dieser Bibliotheken aus dem wissenschaftlichen Umfeld vertraut war. So kam es zu einem sich selbst verstärkenden Zyklus.

Es gibt jedoch auch noch einen anderen Grund für die Beliebtheit von Python im Machine Learning und in der Data Science: Python ist eine großartige Skriptsprache. Python-Code sieht fast aus wie Pseudo-Code, aber er funktioniert. Bei der Diskussion rund um das Für und Wider verschiedener Programmiersprachen darf aber nicht aus dem Blick geraten, dass viele Data Scientists keine Softwareingenieure sind.

Im Job eines Data Scientists geht es nicht darum, eine Low-Level-Programmiersprache zu lernen, um optimalen, auf Performance getrimmten Code zu schreiben. Wenn es darum geht, ein kleines Skript zu bauen, um Daten zu analysieren, ist der Anspruch ein anderer. Ein Data Scientist muss tiefes Wissen in Statistik haben und Algorithmen entwickeln können: Dafür ist Python ideal. Es ist einfach zu nutzen und lässt sich zudem einsetzen, um auf High-Performance-Bibliotheken zuzugreifen, die unter anderem in C geschrieben sind wie NumPy und TensorFlow.

Es gibt jedoch andere Gründe, sich mit Kotlin für ML zu beschäftigen. Das Deployment in die Produktion ist nicht der Abschluss eines Machine-Learning-Projekts, sondern lediglich eine weitere Phase. Die meiste Zeit verbringen Machine Learning Engineers mit der Maintenance und dem Erweitern von Komponenten. Aber je mehr Code entsteht, desto unübersichtlicher wird das Ganze. Wenn die Projekte größer werden, kann ein starkes Typsystem helfen, den Überblick zu behalten. Python ist stark typisiert, aber auch dynamisch. Das hilft bei flexiblen ad-hoc-Analysen. Mit Python lassen sich ohne großen Aufwand Datenstrukturen zusammenbauen, die sogar noch während der Laufzeit definierbar sind.

Im Gegensatz dazu ist Kotlin statisch typisiert. Typen werden zur Compilezeit an Typen gebunden. Das bedeutet, dass viele Typfehler, die in Python zu einem Runtime Error führen könnten, in Kotlin bereits zur Compilezeit auffallen. Das hilft insbesondere Teams, in denen viele gemeinsam an großen Projekten arbeiten.

Ein weiteres Problem in Bezug auf die Wartung und Skalierbarkeit von Projekten ist die Reproduzierbarkeit von Builds und eine vernünftige Verwaltung der Dependencies. In Python sind Builds und Dependencies stark von der Laufzeitumgebung des Codes beeinflusst. Das macht es schwierig, eine Reproduzierbarkeit herzustellen, wenn man nicht das ganze Environment in Docker verpacken will. Mit Kotlin ist es etwas einfacher, stabile Dependency Trees zu bauen.

Was Kotlin Python außerdem voraus hat: Bewegt man sich in Python außerhalb der gekapselten C-Libraries, wird Python ausgesprochen langsam. Im Vergleich hat Kotlin deutlich die Nase vorn.