Portabilität für Deep-Learning-Modelle mit ONNX

Seite 3: Modellübersetzung mit ONNX

Inhaltsverzeichnis

Das Modell hat nun eine brauchbare Qualität, aber die konkrete Beschreibung der Netzwerkarchitektur und dessen Parameter sind abhängig vom Trainingsframework, also von PyTorch. Das Modell ließe sich nun aus PyTorch in die Produktion übertragen. Der Schritt lässt sich jedoch komfortabler mit TensorFlow erledigen, das mit TensorFlow Serving ein dediziertes, performantes und ausgereiftes Werkzeug für das Deployment von Deep-Learning-Modellen bietet. Für den Übergang zwischen den beiden Frameworks kommt ONNX ins Spiel, das als Übersetzer zwischen unterschiedlichen Deep-Learning-Frameworks dient und damit eine Flexibilität bei der Wahl der Tools bietet, die es erlaubt, das Beste aus mehreren Welten zu nutzen.

Amazon, Facebook, Microsoft und über 20 weitere Unternehmen unterstützen ONNX, das erstmals im Dezember 2017 vorgestellt wurde. Es erlaubt den Export und Import trainierter und untrainierter Modelle aus beziehungsweise in diverse Frameworks wie Caffe2, Microsoft Cognitive Toolkit, MXNet, CoreML, TensorFlow und PyTorch. Manche Frameworks integrieren ONNX nativ, für andere existieren Konnektoren, also dedizierte Import- oder Exportmodule (siehe Abbildung unten). Eine Übersicht der Frameworks sowie Tutorials findet sich auf der ONNX-Site.

Die Tabelle zeigt, welche der gängigen Frameworks den Import aus beziehungsweise Export zu ONNX anbieten.

Praktischerweise bietet PyTorch den Modellexport mit der Funktion torch.onnx.export nativ, sodass keine weiteren Pakete erforderlich sind:

example_input = torch.randn(1, 784)


torch.onnx.export(dnn_model,
example_input,
"../models/dnn_model_pt.onnx",
input_names=["flattened_rescaled_img_28x28"]
+ ["weight_1", "bias_1",
"weight_2", "bias_2",
"weight_3", "bias_3"],
output_names=["softmax_probabilities"],
verbose=True)

Die Parameter sind die Modellinstanz dnn_model, ein beispielhafter Input example_input und der Pfad für die Exportdatei. Input, Parameter und Output dürfen dedizierte Namen erhalten. Da der Modellexport durch Tracing erfolgt, benötigt der Export einen Dummy für den Input. Das Netzwerk propagiert diesen und nutzt die aufgerufenen Funktionen zum Erzeugen des ONNX-Graphen.

Die Ausgabe besteht aus einer binären Protobuf-Datei, die die Architektur und Parametrisierung des trainierten Modells umfasst. Ein Zielframework kann das allgemein definierte Modell anschließend importieren. Für das Bereitstellen mit TensorFlow Serving ist zunächst der Import in TensorFlow notwendig, um anschließend eine TensorFlow-spezifische Modelldatei für das Deployment zu erzeugen.

Der benötigte TensorFlow-Konnektor lässt sich über pip install onnx-tf installieren. Der Befehl onnx_tf.backend.prepare konvertiert das eingelesene ONNX-Modell in ein TensorFlow-Modell zum Speichern in einer Protobuf-Datei:

import onnx
from onnx_tf.backend import prepare
model_path = '../models/dnn_model_pt.onnx'

dnn_model_onnx = onnx.load(model_path)
dnn_model_tf = prepare(dnn_model_onnx, device='cpu')
dnn_model_tf.export_graph('../models/dnn_model_tf.pb')

Das eigentliche Deployment erfolgt mit GraphPipe, das unter anderem auf TensorFlow Serving aufsetzt. Dank dem Verwenden von Flatbuffers erreicht es deutliche Performancegewinne gegenüber der standardmäßigen TensorFlow-Serving-Implementierung.

Für das Bereitstellen des Modells in einem lokal verfügbaren Docker-Container existiert ein Docker-Image, das TensorFlow Serving enthält. Der Start erfolgt mit folgendem Befehl:

docker run -it --rm \
-v "$PWD/models:/models/" \
-p 9000:9000 \
sleepsonthefloor/graphpipe-tf:cpu \
--model=/models/dnn_model_tf.pb \
--listen=0.0.0.0:9000

Neben der in der vorletzten Zeile angegebenen TensorFlow-Modelldatei ist in der letzten Zeile der Port definiert, unter dem sich das Modell via REST abfragen lässt.

Die Ausgabe sollte in etwa folgendermaßen aussehen:

INFO[0000] Starting graphpipe-tf version 1.0.0.10.f235920 

...

INFO[0000] Using default inputs [flattened_rescaled_img_28x28:0]
INFO[0000] Using default outputs [Softmax:0]
INFO[0000] Listening on '0.0.0.0:9000'

Nach dem Deployment ist es Zeit, die Installation zu testen. Mit requests lassen sich REST-Aufrufe an den lokalen Server schicken. Im Folgenden dient dazu der über pip install graphpipe installierte GraphPipe-Client, der fünf Beispielbilder aus den Testdaten an den Server sendet:

from graphpipe import remote

n_test_instances = 5
n_test = x_test.shape[0]

for _ in range(n_test_instances):
idx = np.random.randint(n_test)
# flatten and normalize test image
x = x_test[idx].reshape(1, -1)/255
y = y_test[idx][0]
softmax_pred = remote.execute("http://127.0.0.1:9000", x)
pred_class = mapping[np.argmax(softmax_pred)]
true_class = mapping[y_test[idx][0]]
print("Predicted Label / True Label: {} == {} ? - {} !"
.format(pred_class, true_class,
(pred_class==true_class)))

Der Modellserver gibt die passenden Klassifikationen zurück:

Predicted Label / True Label: 5 == 5 ? - True !
Predicted Label / True Label: 8 == 8 ? - True !
Predicted Label / True Label: 9 == 9 ? - True !
Predicted Label / True Label: F == F ? - True !
Predicted Label / True Label: 3 == S ? - False !

Als zusätzliche Information liefert das Backend die Inferenz-Zeiten:

INFO[0379] Request for / took 194.5615ms
INFO[0379] Request for / took 1.4576ms
INFO[0379] Request for / took 1.5524ms
INFO[0379] Request for / took 4.4992ms
INFO[0379] Request for / took 6.5824ms