Python: Richtige JIT-Optimierung – damit nichts schiefgeht

Seite 2: Wie die Optimierung zustande kommt

Inhaltsverzeichnis

Wie kommt nun die Optimierung zustande? Dafür ist das Paket Numba zuständig – über from numba import jit lässt sich der dort definierte @jit-Dekorator importieren. Die Schwierigkeit in der JIT-optimierten Programmierung besteht darin, dass die Verwendung von Befehlen eingeschränkt ist und Entwickler nicht einfach beliebigen Code optimieren können. Deshalb muss er optimierbare Operationen in eine dafür gesondert vorgesehene Routine auslagern. Im vorliegenden Fall ist es die Funktion generateData, die mit dem @jit-Dekorator versehen ist.

Die Routine sucht mithilfe dreier verschachtelter for-Schleifen, in denen sie einfache Operationen wie Addition, Subtraktion, Multiplikation und Wurzelziehen verwendet, nach den Zahlen x, y, z. Am Ende gibt sie die Anzahl der gefunden Ergebnisse zurück. Das Speichern der gefundenen Lösungen in eine CSV-Datei erfolgt außerhalb dieser Routine. Fügen Programmierer derartigen, beispielsweise auf pandas basierenden Code zur Erzeugung von CSV-Dateien in die per @jit-Dekorator versehene Methode ein, erhalten sie beim Aufruf des Python-Programms einen Fehler. Der Grund dafür ist, dass derartiger Code nicht optimierbar ist. Allerdings existiert eine Reihe von Fallstricken, die zwar keinen Fehler, aber einen enormen Performance-Verlust und falsche Ergebnisse verursachen.

Der Code in Listing 1 ist so beschaffen, dass er ein leeres, bereits mit vorgegebener Größe initialisiertes Array an die optimierte Routine übergibt. Das Programm befüllt dieses Array innerhalb der Route selbst. Alternativ könnten Entwickler dem Gedanken verfallen, das Ergebnis-Array dynamisch mit den gefundenen Lösungen wachsen zu lassen. Eine Suche von Lösungen, deren Werte x, y, z maximal 50000 annehmen, ergibt viel weniger Ergebnisse als die festgelegte Grenze von 50000 – in der Tat sind 1074 Treffer deutlich weniger als 50000. Auf diese Weise würde das Programm auch keinen Platz verschwenden.

Der Programm-Code im Listing 2 realisiert diese Platzersparnis und hält sich dabei an die Vorgabe, dass er an die optimierte Routine ein initialisiertes Array übergeben soll. Allerdings zerstört diese Vorgehensweise den gesamten Performancegewinn und eine Suche (wieder mit limit = 50000) dauert nun 1203,3 Sekunden (20 Minuten). Es ist also von Vorteil, wenn sich Programmierer auf diese "Platzverschwendung" einlassen. Wer dieses Beispiel selbst nachstellen möchte, findet den vollständigen Code im GitHub-Repository (siehe Datei pythagorean_gendata_jit_antipattern.py).

arr = np.array([x, y, z, s, t, w], dtype=np.uint64)
old_size = triples.shape
rows = np.uint32(old_size[0])
cols = np.uint32(old_size[1])
triples.resize((rows + 1, cols), refcheck=False)
triples[rows] = arr

Listing 2: Fallstrick 1 – "Teure" Operationen in der optimierten Routine.

Bei der Gegenüberstellung der Laufzeiten, wie sie in Tabelle 1 zu sehen sind, ergibt sich ein beeindruckender Vergleich: Die korrekt umgesetzte Optimierung bringt eine Beschleunigung um den Faktor 225.

Keine Optimierung
Optimierung
Optimierung mit Fallstrick 1
1082,8 Sekunden
4,8 Sekunden
1203,3 Sekunden

Tabelle 1: Gegenüberstellung der Laufzeiten

Setzen Programmierer eine lang laufende Routine ein, deren Laufzeit nicht nur Tage, sondern Wochen dauert, ist es wünschenswert, in gewissen Zeitabständen Zwischenergebnisse zu sichern – beispielsweise in einer CSV-Datei. Der Schreibvorgang darf aber nicht innerhalb der optimierten Routine, sondern muss außerhalb erfolgen. Listing 3 zeigt eine Routine mit drei ineinander verschachtelten for-Schleifen zur alternativen Generierung potenzieller Tupel für das Sechs-Quadrate-Problem. Der vollständige Code ist im GitHub-Repository (siehe Datei pythagorean_gendata2_nofile_jit.py) verfügbar. Dieser Code lässt sich per JIT gut optimieren und bietet die die erwartete Beschleunigung um dreistellige Faktoren.

@jit('void(uint64)')
def generateData(limit: np.uint64):
    for t in np.arange(0, limit+1, dtype=np.uint64):
        for s in np.arange(0, t, dtype=np.uint64):
            for u in np.arange(0, limit+1, dtype=np.uint64)

Listing 3: Ineinander geschachtelte For-Schleifen mit festen Bereichen kann JIT optimieren.

Speichern Entwickler die gefunden Ergebnisse in einer Datei zwischen, entfällt dabei die gesamte Optimierung. Jede Fundstelle wird an den (außerhalb der JIT-Optimierung laufenden) Aufrufer zurückgegeben, der sie wiederum in eine Datei schreibt (Listing 4). Der Mechanismus hebt die gesamte JIT-Optimierung auf. Im GitHub-Repository liegt der vollständige Code zum Ausprobieren bereit (siehe Datei pythagorean_gendata2_jit_antipattern.py).

@jit('void(uint64, uint64[:])')
def generateData(limit: np.uint64, triplet: np.ndarray) -> np.ndarray:
    s0 = triplet[0]
    t0 = triplet[1]
    u0 = triplet[2]
    for t in np.arange(t0, limit+1, dtype=np.uint64):
        for s in np.arange(s0, t, dtype=np.uint64):
            for u in np.arange(u0, limit+1, dtype=np.uint64):
               �
               if sqr*sqr == t_s:
                   return np.array([s, t, u, ss, tt, uu, t_u, t_u_s, t_s], dtype=np.uint64)

Listing 4: Aufheben der Optimierung durch dynamische Schleifen (Fallstrick 2).

Der dritte Fallstrick zeigt sich anhand eines in der Kryptographie verbreiteten Problems: Elliptische Kurven. Eine häufige Fragestellung, die direkt damit im Zusammenhang steht, lautet: Liegen auf der Kurve rationale oder sogar ganzzahlige Punkte? Die Frage wurde beispielsweise auf StackExchange Mathematics für die Kurve y2=x6−4x2+4 gestellt und beantwortet. Es sei an dieser Stelle erwähnt, dass es mächtigere mathematische (beispielsweise probabilistische) Methoden für die Suche von rationalen Punkten auf elliptischen Kurven gibt. Zu vereinfachten Illustrationszwecken des Performancegewinns via JIT und GPU bleibt dieses Beispiel beim einfachen Bruteforce.Die Routine in Listing 5, die in diesem Code korrekt optimiert ist, dient zur Suche von ganzzahligen Lösungen dieser Kurve.

@jit('void(uint64)')
def findIntegerSolutions(limit: np.uint64):
    for x in np.arange(0, limit+1, dtype=np.uint64):
        y = np.uint64(x**6-4*x**2+4)
        sqr = np.uint64(np.sqrt(y))
        if np.uint64(sqr*sqr) == y:
            print([x,sqr,y])

Listing 5: Optimierte Routine zur Suche von ganzzahligen Punkten auf der Kurve y2=x6−4x2+4