Replace js-tiktoken BPE merge algorithm with faster heap based algorithm

Open mikolalysenko opened this issue 1 year ago • 6 comments

Several open issues note that, in the worst case, the BPE merge algorithm in js-tiktoken takes time quadratic in the number of input characters for certain pathological inputs.

This PR fixes the problem by using a heap (priority queue), which avoids recalculating the ranks of all adjacent pairs after every merge. The same technique should also work for the Rust/WASM tokenizer, but it seems less important there, since the native implementations are already fairly fast.
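For contrast, here is a minimal sketch of the kind of naive merge loop the issue describes. It is modeled on the description above rather than copied from js-tiktoken, and it assumes (as the code posted later in this thread does) that ranks are keyed by comma-joined byte values. Every merge rescans all adjacent pairs, so a piece of n bytes can need up to n merges of O(n) work each, i.e. O(n²) overall.

```typescript
// Sketch of a naive BPE merge: rescan every adjacent pair after each merge.
// Assumption: `ranks` maps comma-joined byte values (e.g. "97,97") to a rank,
// where a lower rank means the pair is merged earlier.
type Part = { start: number; end: number };

function naiveBytePairMerge(piece: Uint8Array, ranks: Map<string, number>): Part[] {
  // Start with one part per byte.
  const parts: Part[] = Array.from({ length: piece.length }, (_, i) => ({ start: i, end: i + 1 }));

  while (parts.length > 1) {
    let best = -1;
    let bestRank = Infinity;
    // O(parts.length) scan on every iteration: this is the quadratic part.
    for (let i = 0; i < parts.length - 1; i++) {
      const rank = ranks.get(piece.slice(parts[i].start, parts[i + 1].end).join(","));
      // Strict `<` keeps the leftmost pair when ranks tie.
      if (rank !== undefined && rank < bestRank) {
        bestRank = rank;
        best = i;
      }
    }
    if (best < 0) break; // no adjacent pair is a known token
    // Merge parts[best] with its right neighbor; splice is also O(n).
    parts[best] = { start: parts[best].start, end: parts[best + 1].end };
    parts.splice(best + 1, 1);
  }
  return parts;
}
```

For example, with a toy map where `"97,97"` has rank 0 and `"97,97,97,97"` has rank 1, the bytes of `"aaaa"` end up as the single part `{ start: 0, end: 4 }`, while `"aaa"` ends up as `aa|a`. A heap-based version instead re-ranks only the two pairs next to each merge, bringing the total down to roughly O(n log n) heap operations.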

I also added a new test fixture and an example string which causes pathological behavior.

Related issues:

  • https://github.com/dqbd/tiktoken/issues/99
  • https://github.com/openai/tiktoken/issues/195

Note: This arguably warrants a low-severity CVE, since an attacker could use this behavior to mount a denial-of-service attack against services that check user input with js-tiktoken.

mikolalysenko, Apr 19 '24

@dqbd, are we considering merging this? We also see performance issues with large inputs.

huytool157, May 07 '24

@mikolalysenko, have you benchmarked this against the baseline and the WASM implementation? I'm curious about 1, 1,000, 1M, and 100M tokens.

enricoros, Jun 05 '24
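One way to answer the benchmarking question is a small timing harness like the hypothetical one below. It takes any encode function (for example `encode` from js-tiktoken's `getEncoding("cl100k_base")`, or the WASM build's encoder), and, as an assumption, it measures input size in characters rather than tokens. `Date.now()` timings are coarse and machine-dependent, so treat the numbers as rough.

```typescript
// Hypothetical benchmark harness: times each encoder on inputs of several sizes.
// Sizes are in characters (an assumption; the question above is about tokens).
type Encoder = (text: string) => ArrayLike<number>;
type BenchRow = { name: string; size: number; ms: number; tokens: number };

function benchmark(
  encoders: Record<string, Encoder>,
  sizes: number[],
  makeInput: (size: number) => string
): BenchRow[] {
  const rows: BenchRow[] = [];
  for (const size of sizes) {
    // Build the input once per size so every encoder sees identical text.
    const input = makeInput(size);
    for (const [name, encode] of Object.entries(encoders)) {
      const startMs = Date.now();
      const tokens = encode(input).length;
      rows.push({ name, size, ms: Date.now() - startMs, tokens });
    }
  }
  return rows;
}
```

Usage might look like `console.table(benchmark({ lite: (s) => lite.encode(s) }, [1, 1_000, 1_000_000], (n) => "a".repeat(n)))`, where `lite` is a js-tiktoken encoder; adding a WASM encoder under another key gives a side-by-side comparison at each size.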

> have you benchmarked this vs baseline vs wasm? curious for 1, 1000, 1M, 100M tokens.

If performance differs at different scales, we should consider switching algorithms based on input length.

jonschlinkert, Jul 04 '24
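Toggling by length could be as simple as a wrapper chosen once at construction time. In this hypothetical sketch, `makeLengthDispatchedMerge` and the threshold of 128 bytes are illustrative placeholders, not part of js-tiktoken. Since tiktoken-style encoders split text with a regex before running the merge, the length that matters is presumably that of each piece, not the whole input.

```typescript
// Hypothetical sketch: choose a merge implementation by piece length.
// MERGE_LENGTH_THRESHOLD is a placeholder to be tuned from benchmarks, not a measured value.
type Part = { start: number; end: number };
type MergeFn = (piece: Uint8Array, ranks: Map<string, number>) => Part[];

const MERGE_LENGTH_THRESHOLD = 128;

function makeLengthDispatchedMerge(
  shortPieces: MergeFn,
  longPieces: MergeFn,
  threshold: number = MERGE_LENGTH_THRESHOLD
): MergeFn {
  // Pieces shorter than the threshold use `shortPieces` (e.g. a simple linear scan);
  // pieces at or above it use `longPieces` (e.g. a heap-based merge).
  return (piece, ranks) =>
    piece.length < threshold ? shortPieces(piece, ranks) : longPieces(piece, ranks);
}
```

Whether the simple scan is actually faster on short pieces is exactly what the benchmark request above would settle, so the threshold should come from measurements rather than intuition.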

It would be lovely if this got merged; this module has been causing issues for my application in production.

tmcw, Nov 20 '24

+1

danny-avila, Dec 30 '24

I'm not sure whether this fully works, but try modifying core.ts as follows:

import base64 from "base64-js";
import type { TiktokenModel } from "./ranks/ranks";
import { never } from "./utils";

// One node per part; nodes form a doubly linked list in byte order.
type BPEMergeNode = {
  listNext: BPEMergeNode | null;
  listPrev: BPEMergeNode | null;

  // Set once this node has been merged into its predecessor.
  deleted: boolean;

  // Rank of the pair (this part + listNext), or Infinity if it is not a token.
  rank: number;
  start: number;
  end: number;
};

// Heap entries snapshot a node's rank at push time. An entry whose rank no
// longer equals node.rank is stale and is skipped when popped (lazy deletion).
// A node's pair only ever grows and token ranks are unique, so a rank value
// identifies exactly one state of the pair.
type HeapEntry = { rank: number; node: BPEMergeNode };

// Lowest rank first; ties go to the leftmost pair, matching a left-to-right scan.
function compareEntry(a: HeapEntry, b: HeapEntry) {
  return a.rank - b.rank || a.node.start - b.node.start;
}

// Helper function to swap elements at two indices
function swap(heap: HeapEntry[], i: number, j: number) {
  const temp = heap[i];
  heap[i] = heap[j];
  heap[j] = temp;
}

// Standard binary min-heap push (originally GPT-4 generated): append, then sift up.
function heapPush(heap: HeapEntry[], entry: HeapEntry) {
  heap.push(entry);
  let currentIndex = heap.length - 1;
  while (currentIndex > 0) {
    const parentIndex = Math.floor((currentIndex - 1) / 2);
    if (compareEntry(heap[currentIndex], heap[parentIndex]) >= 0) break;
    swap(heap, currentIndex, parentIndex);
    currentIndex = parentIndex;
  }
}

// Standard binary min-heap pop (also AI generated): move the last element to
// the root, then sift down. Returns undefined if the heap is empty.
function heapPop(heap: HeapEntry[]): HeapEntry | undefined {
  if (heap.length === 0) return undefined;

  const rootValue = heap[0];
  const lastValue = heap.pop() as HeapEntry;

  if (heap.length > 0) {
    heap[0] = lastValue;
    let currentIndex = 0;
    while (true) {
      const leftChildIndex = 2 * currentIndex + 1;
      const rightChildIndex = 2 * currentIndex + 2;
      let smallestIndex = currentIndex;
      if (leftChildIndex < heap.length && compareEntry(heap[leftChildIndex], heap[smallestIndex]) < 0) {
        smallestIndex = leftChildIndex;
      }
      if (rightChildIndex < heap.length && compareEntry(heap[rightChildIndex], heap[smallestIndex]) < 0) {
        smallestIndex = rightChildIndex;
      }
      if (smallestIndex === currentIndex) break;
      swap(heap, currentIndex, smallestIndex);
      currentIndex = smallestIndex;
    }
  }

  return rootValue;
}

function bytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Array<{ start: number; end: number }> {
  if (piece.length === 0) {
    return [];
  }

  // Rank of piece[start, end) as a single token, or Infinity if it is not one.
  const rankOf = (start: number, end: number): number =>
    ranks.get(piece.slice(start, end).join(",")) ?? Infinity;

  // One node per byte, linked into a doubly linked list.
  const parts: BPEMergeNode[] = Array.from({ length: piece.length }, (_, i) => ({
    start: i,
    end: i + 1,
    rank: Infinity,
    deleted: false,
    listNext: null,
    listPrev: null,
  }));
  for (let i = 0; i < parts.length; ++i) {
    parts[i].listPrev = parts[i - 1] ?? null;
    parts[i].listNext = parts[i + 1] ?? null;
  }

  const heap: HeapEntry[] = [];

  // Recompute the rank of (node, node.listNext) and push a fresh entry if it
  // is a token. Any older entry for this node becomes stale automatically.
  const refresh = (node: BPEMergeNode) => {
    node.rank = node.listNext ? rankOf(node.start, node.listNext.end) : Infinity;
    if (node.rank !== Infinity) {
      heapPush(heap, { rank: node.rank, node });
    }
  };

  // Initialize heap with valid merges
  for (const part of parts) {
    refresh(part);
  }

  while (heap.length > 0) {
    const entry = heapPop(heap);
    if (!entry) break;
    const part = entry.node;

    // Skip merged-away nodes and entries whose pair has changed since the push.
    if (part.deleted || entry.rank !== part.rank || !part.listNext) {
      continue;
    }

    // Perform merge: absorb the successor and unlink it.
    const next = part.listNext;
    part.end = next.end;
    next.deleted = true;
    part.listNext = next.listNext;
    if (part.listNext) {
      part.listNext.listPrev = part;
    }

    // Only the pairs (listPrev, part) and (part, listNext) changed.
    refresh(part);
    if (part.listPrev) {
      refresh(part.listPrev);
    }
  }

  // Walk the list to collect the surviving parts in order.
  const result: Array<{ start: number; end: number }> = [];
  for (let current: BPEMergeNode | null = parts[0]; current; current = current.listNext) {
    result.push({ start: current.start, end: current.end });
  }
  return result;
}


// rest of code unchanged

With this change the tests pass. With the current PR, the lite tokenizer produces different tokens than the full tokenizer for emojis and non-Latin characters:

FAIL  test/compatibility.test.ts > LiteTokenizer matches the behavior of tiktoken > Emojis and non-latin characters
js-tiktoken:test: AssertionError: expected [ 9468, 239, 102, 378, 235, …(109) ] to deeply equal [ 9468, 239, 102, 378, 235, …(111) ]
js-tiktoken:test:  ❯ test/compatibility.test.ts:50:38
js-tiktoken:test:      48| 
js-tiktoken:test:      49|     for (const text of fixtures) {
js-tiktoken:test:      50|       expect([...lite.encode(text)]).toEqual([...full.encode(text)]);
js-tiktoken:test:        |                                      ^
js-tiktoken:test:      51|     }
js-tiktoken:test:      52|   });

With this lite_performance.test.js:

import { test, expect, describe } from "vitest";
import { getEncoding } from "../src/index";
const TARGET_TIME = 30_000;
const TARGET_STRING_LENGTH = 1_000_000; // Crazy high number to test the limits of the lite tokenizer
const EVIL_STRING = Array.from({ length: TARGET_STRING_LENGTH }, () => {
  return String.fromCharCode(Math.floor(Math.random() * 256));
}).join("");

// This test will be flaky - so perhaps we should run it externally
// from the main CI pipeline since it depends on the machine it's run on
describe(`Lite tokenizer resolves ${EVIL_STRING.length / 1000}K string in acceptable time (${TARGET_TIME}ms)`, () => {
  const lite = getEncoding("cl100k_base");
  test("Test lite performance", () => {
    const start = Date.now();
    lite.encode(EVIL_STRING);
    const end = Date.now();
    console.log(`Lite encoding time: ${end - start}ms`);
    expect(end - start).toBeLessThanOrEqual(TARGET_TIME);
  });

  test("Test encoding/decoding", () => {
    const result = lite.encode(EVIL_STRING);
    const decoded = lite.decode(result);
    expect(decoded).toEqual(EVIL_STRING);
  });
});

With a crazy length of 1,000,000 characters, encoding takes 1908 ms on an Intel MBP.

timothycarambat, Feb 19 '25
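The compatibility failure reported above (the lite encoder producing different tokens than the full one) is the kind of bug a randomized differential check can catch early. In this hypothetical sketch, `findMismatch` feeds random byte strings to a reference merge function (for example the original linear-scan merge) and a candidate (for example a heap-based rewrite) and returns the first input on which their outputs differ. The small default alphabet is an assumption chosen so that toy merges actually fire.

```typescript
// Hypothetical differential-testing helper (not part of js-tiktoken): runs two
// merge implementations on random byte strings and reports the first disagreement.
type Part = { start: number; end: number };
type MergeFn = (piece: Uint8Array, ranks: Map<string, number>) => Part[];

function findMismatch(
  reference: MergeFn,
  candidate: MergeFn,
  ranks: Map<string, number>,
  trials = 1000,
  maxLength = 32,
  // Bytes for "a", "b", "c": a small alphabet makes repeated pairs likely.
  alphabet: number[] = [97, 98, 99]
): Uint8Array | null {
  for (let t = 0; t < trials; t++) {
    const length = Math.floor(Math.random() * (maxLength + 1));
    const piece = Uint8Array.from({ length }, () => alphabet[Math.floor(Math.random() * alphabet.length)]);
    const expected = JSON.stringify(reference(piece, ranks));
    const actual = JSON.stringify(candidate(piece, ranks));
    if (expected !== actual) {
      return piece; // first failing input; a good candidate for a new test fixture
    }
  }
  return null; // no disagreement found in `trials` random inputs
}
```

A failing input returned here can be shrunk by hand and added to the fixtures, so the regression stays covered even after the random check moves on.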