Gleitkommazahlen im Machine Learning: Weniger ist mehr für Intel, Nvidia und ARM
Die drei Hardwarehersteller haben einen gemeinsamen Vorstoß für das standardisierte 8-Bit-Format "FP8" für Gleitkommazahlen veröffentlicht.
ARM, Intel und Nvidia machen sich gemeinsam für ein Gleitkommazahlenformat mit 8 Bit stark. Der neue Standard FP8 soll die Zahlen definieren und die in IEEE 754 beschriebenen Formate für 16, 32 und 64 Bit ergänzen. Motivation ist der Einsatz im Training von Machine-Learning-Modellen, das trotz des schmaleren Datenformats ohne nennenswerte Einbußen bei der Genauigkeit deutlich schneller laufen soll.
Präzision, Exponenten und Mantissen
Die Norm IEEE 754 definiert den Aufbau von Gleitkommazahlen unterschiedlicher Präzision. Als Basis dient die einfache Präzision 32 Bit (FP32). Passend dazu benötigen Gleitkommazahlen doppelter Präzision 64 Bit (FP64) und für die halbe Präzision reichen 16 Bit (FP16). Daneben existieren sogenannte Minifloats mit geringerer Bit-Anzahl, die nicht in der IEEE-Norm definiert sind.
Einige Programmiersprachen kennen unterschiedliche Typen wie float
und double
in C und C++ sowie f32
und f64
in Rust. Andere Sprachen wie JavaScript und das im Data-Science- und Machine-Learning-Umfeld verbreitete Python setzen standardmäßig auf 64-Bit-Gleitkommazahlen.
Unabhängig von der Bit-Breite stellen alle Formate die Zahlen in einer Kombination aus Mantisse (m) und Exponent (e) dar. Erstere enthält die Ziffern, während Letztere den Exponenten definiert. Hinzu kommt die Basis (b), die bei IEEE 754 2 ist. Die tatsächlich Zahl ergibt sich durch die Formel x = m * 2e. Der Exponent ist vorzeichenbehaftet und kann beispielsweise bei einer Gleitkommazahl einfacher Präzision, die 23 Bit für die Mantisse und 8 für den Exponenten vorsieht, zwischen -126 und 127 liegen. Alle Präzisionsstufen verwenden ein Bit für das Vorzeichen der Zahl.
Weniger ist mehr
Während bei typischen Anwendungen die Breite der Fließkommazahlen keine nennenswerten Auswirkungen auf die Performance hat, fallen im Machine Learning gerade beim Trainieren von Modellen, aber auch bei der Inferenz im produktiven Einsatz unzählige Berechnungen parallel an, die typischerweise auf dedizierter Hardware wie GPUs oder KI-Beschleunigern stattfindet.
Daher versuchen diverse Ansätze seit geraumer Zeit die Gleitkommazahlenformate für Machine Learning zu optimieren. Viele setzen auf einen Mixed-Precision-Weg, der beispielsweise Gleitkommazahlen halber und einfacher Präzision, aber auch FP16 mit 8-Bit-Integer-Werten (INT8) mischt. Viele aktuelle Prozessoren und Rechenbeschleuniger verarbeiten auch BFloat16 (BF16).
Deutlich schneller, aber fast so exakt
Laut einem Paper auf arXiv leidet die Genauigkeit vieler Vision- und sprachbasierter Modelle kaum unter einer geringeren Präzision der Gleitkommazahlen. Es vergleicht den Einsatz von 8-Bit-Gleitkommazahlen (FP8) mit dem von FP16-Werten unter anderem bei den Sprach- beziehungsweise Transformer-Modellen BERT und GPT-3.
Deutliche Unterschiede zeigen sich dagegen bei der Performance: Nvidia hat mit dem Performance-Messwerkzeug MLPerf Inference 2.1 Messungen auf der Anfang 2022 eingeführten Hopper-Architektur durchgeführt. Die Verarbeitung erfolgte laut dem Unternehmen viereinhalbmal so schnell wie mit FP16.
Ein Standard für FP8
Um eine einheitliche Verarbeitung der 8-Bit-Gleitkommazahlen zu gewährleisten, wollen ARM, Intel und Nvidia gemeinsam einen Standardisierungsvorschlag vorantreiben. Das FP8-Format sieht zwei verschiedene Ausprägungen vor: E5M2 nutzt zwei Bits für die Mantisse und fünf für den Exponenten, während E4M3 drei Bits auf die Mantisse verteilt und nur vier auf den Exponenten. Letzteres kann also genauer arbeiten, aber in einem kleineren Zahlenspektrum.
Daneben soll FP8 die Sonderfälle Null, unendlich und NaN (Not a Number), also einen undefinierten oder nicht darstellbaren Wert abdecken. Das FP8-Format ist wohl nativ in die Nvidia-Hopper-Architektur integriert.
Weitere Details zu dem Gleitkommazahlenformat FP8 und den Standardisierungsbemühungen lassen sich dem Nvidia-Blog entnehmen.
In der Formel x = m * 2e wurde die ursprünglich als 10 angegebene Basis auf b korrigiert.
(rme)