
Entry function 'fusion_2' uses too much parameter space

[Open] guyvdbroeck opened this issue 6 years ago • 7 comments

JAX's jit works fine most of the time, but occasionally I get the following error, which I don't know how to interpret.

2019-05-05 21:59:48.182440: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:125] Can't find ptxas binary. Will back to the GPU driver for PTX -> sass compilation. This is OK so long as you don't see a warning below about an out-of-date driver version.
2019-05-05 21:59:48.182510: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:126] Searched in the following directories:
2019-05-05 21:59:48.182670: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:130] ./cuda_sdk_lib
2019-05-05 21:59:48.182692: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:130] /usr/local/cuda
2019-05-05 21:59:48.182715: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:130] .
2019-05-05 21:59:48.182733: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:132] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions. For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2019-05-05 21:59:49.158345: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:625] failed to load PTX text as a module: CUDA_ERROR_INVALID_PTX: a PTX JIT compilation failed
2019-05-05 21:59:49.158413: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:630] error log buffer (185 bytes): ptxas application ptx input, line 34370; error : Entry function 'fusion_2' uses too much parameter space (0x1898 bytes, 0x1100 max).
ptxas fatal : Ptx assembly aborted due to error
2019-05-05 21:59:49.158640: F external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:888] Check failed: module != nullptr
Aborted

guyvdbroeck avatar May 05 '19 22:05 guyvdbroeck

It would be helpful if you could share a small reproduction of the problem; it's likely something we should hand off to the fine folks who work on XLA.

hawkinsp avatar May 06 '19 00:05 hawkinsp

It would also be helpful to know what versions of jax and jaxlib you have installed.

hawkinsp avatar May 06 '19 17:05 hawkinsp

I was using the git version from last weekend. Unfortunately the code is quite complex and I cannot share it at the moment.

guyvdbroeck avatar May 07 '19 10:05 guyvdbroeck

I think it's going to be hard to make progress on this one without a way to reproduce it. We have some guesses as to how to trigger it, but we aren't completely sure.

hawkinsp avatar May 10 '19 19:05 hawkinsp

Thanks. I believe the underlying cause is that the computation graph I give to jit is simply too large (about 10,000 np operations).
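
For illustration (this is not the actual code, just a minimal sketch of the pattern), the problem arises when a Python loop is fully unrolled at trace time, so a single jitted function lowers to one XLA module containing thousands of primitive ops:

import jax
import jax.numpy as jnp

@jax.jit
def big_graph(x):
  # The Python loop is unrolled while tracing, so this one jitted function
  # lowers to a single XLA module with roughly 30,000 primitive ops.
  for _ in range(10_000):
    x = jnp.tanh(x) * 0.5 + x
  return x

# big_graph(jnp.ones(16))  # compiling this single huge module is what fails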

guyvdbroeck avatar May 16 '19 20:05 guyvdbroeck

I managed to isolate the code that produces the error: bug.zip. As I thought, it's probably due to the jit compiler running out of memory. The file also shows how I'm trying to avoid this blow-up of the compiler: by isolating sub-functions to be compiled separately. This works, but compilation is still very slow, and it would be nice if there were ways to speed it up or to parallelize the simultaneous jit of many functions on the CPU (cf. #679).
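
The splitting workaround amounts to something like the sketch below (hypothetical toy functions continuing the example above; the real version is in bug.zip and in the script posted later in this thread):

import jax
import jax.numpy as jnp

def chunk_body(x, n):
  for _ in range(n):
    x = jnp.tanh(x) * 0.5 + x
  return x

def big_graph_chunked(x, total=10_000, chunk=128):
  # Jit each chunk as its own function so XLA compiles many small modules
  # instead of one enormous one. A fresh lambda is wrapped per chunk, so every
  # chunk still compiles separately, which is why compilation remains slow.
  for _ in range(total // chunk):
    x = jax.jit(lambda x: chunk_body(x, chunk))(x)
  return x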

guyvdbroeck avatar May 17 '19 06:05 guyvdbroeck

Thanks for the reproducer. I have played around with it and tried several layer sizes. With your workaround you are making sure that several XlaModules are created. The size of the emitted LLVM IR grows linearly with the size of the XlaModule; however, it seems that the ptxas memory usage grows more than linearly in the size of the IR. ptxas is not under our control, but arguably XLA or JAX should automatically split modules that are too large for ptxas to handle.
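
For anyone debugging a similar case, one rough way to see how large the generated modules actually are is to have XLA dump what it compiles (the same XLA_FLAGS mechanism the slow-compile warning later in this thread mentions); a sketch, with a hypothetical dump directory:

# Set the dump flag before jax initializes its backend.
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax
import jax.numpy as jnp

jax.jit(lambda x: jnp.tanh(x) * 2.0)(jnp.ones(8))
# /tmp/xla_dump now contains the HLO text for each compiled module; comparing
# dump sizes for different layer_size values shows how quickly the modules grow.
# (On GPU builds the dump may also include LLVM IR / PTX, depending on options.)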

akuegel avatar May 28 '19 13:05 akuegel

Good news, at long last! This bug doesn't repro anymore :) I think the compiler was fixed.

Compilation is still super slow, and there are probably ways we'd want to adapt this code to make it more amenable to JAX. But for the narrow scope of this bug, we're good!

For posterity, I tested with the revised version of the script below (the only changes were replacing index_update with .at[...].set and working around the list-valued static_argnums arguments). It executes to completion without a compiler error on an A100 VM, producing the output below:
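
In isolation, the two changes look roughly like this (a hedged sketch; the old jax.ops API has since been removed, and static_argnums only accepts hashable values, which a Python list is not):

import jax
import jax.numpy as jnp

x = jnp.zeros(5)
# Old API (removed):  from jax.ops import index, index_update
#                     x = index_update(x, index[1:3], 7.0)
# Current API: functional indexed update
x = x.at[1:3].set(7.0)

wiring = [[0, 1], [2, 3]]  # a list: not hashable
# The script closes over the list and jits the closure instead of marking it
# static; converting it to a nested tuple and using static_argnums also works.
step = jax.jit(lambda probs: probs + len(wiring))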

$ python bug.py

Jit compiling with layer size 128
1st batch took 144.22 sec
2nd batch took 135.62 sec
It would be great to use multi-core for the initial jit compilation of all independent layers, in order to reduce this huge upfront cost.
Jit compiling with layer size 10000 (this will trigger a bug in the compiler, and use about 60GB of RAM on my server)
2022-08-15 19:02:58.365163: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65]
********************************
[Compiling module jit__lambda_.143] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2022-08-15 19:03:28.264408: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m29.899393287s

********************************
[Compiling module jit__lambda_.143] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
$ 
# -*- coding: utf-8 -*-

import numpy as onp
import time
import itertools

import jax.numpy as np
from jax import jit, device_put
from jax import random
# from jax.ops import index, index_add, index_update

def pos_lit(v):
  return v-1

def neg_lit(v, num_features):
  return num_features+v-1

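# Parse a circuit description into (num_inputs, num_gates, num_wires, wiring):
# 'T'/'F' lines define positive/negative input literals, and 'D' lines define
# decision nodes whose elements are pairs of child node ids.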
def parse_lc(lines, num_features):
  wiring = list(parse_lc_gen(lines, num_features))
  return (num_features, len(wiring)+2*num_features, -1, wiring)

def parse_lc_gen(lines, num_features):
  node_map = {}
  inode_ct = itertools.count(2*num_features)

  def inode(id1):
    if id1 not in node_map:
      inode = next(inode_ct)
      node_map[id1] = inode
    return node_map[id1]

  for x in lines:
    spl = x.split()
    if x.startswith('T'):
        id1 = spl[1]
        v = int(spl[3])
        node_map[id1] = pos_lit(v)
    if x.startswith('F'):
        id1 = spl[1]
        v = int(spl[3])
        node_map[id1] = neg_lit(v, num_features)
    if x.startswith('D'):
        id1 = spl[1]
        elements = [el.split(" ") for el in x.replace(')', '').split("(")[1:]]
        elements = [[inode(c) for c in e[0:2]] for e in elements]
        yield [inode(id1), 'D', elements, -1]

def parse_lc_file(file, num_features):
  return parse_lc(open(file, "r"), num_features)

def one_hot(x, k):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), np.float32)

key = random.PRNGKey(0)

my_batch_size = 1024
num_features = 28*28
num_labels = 10

big_circuit = parse_lc_file("mnist.circuit",num_features)

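# Evaluate the circuit bottom-up in chunks of `layer_size` gates, jitting each
# chunk separately so that no single XLA module becomes too large (the
# workaround for the ptxas parameter-space error discussed above).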
def batch_node_probs(inputs, circuit, layer_size):
  num_inputs, num_gates, num_wires, wiring = circuit
  probs = j_batch_node_probs_init(inputs, num_inputs, num_gates)
  for i in range(0, len(wiring), layer_size):
    probs = j_batch_node_probs_layer(probs, i, i+layer_size, wiring)
  return probs

def batch_node_probs_layer(probs, start, end, wiring):
  for node in wiring[start:end]:
    op = node[1]
    if(op == 'D'):
      id, op, elements, weights = node
      probs = batch_node_probs_dec(probs, id, elements)
    else: raise Exception(f"unknown operand: {op}")
  return probs

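# `wiring` is a Python list (unhashable), so it cannot be passed via
# static_argnums; instead, jit a fresh closure over (start, end, wiring) for
# each layer. (See the commented-out static_argnums variant below.)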
def j_batch_node_probs_layer(probs, start, end, wiring):
    return jit(lambda probs: batch_node_probs_layer(probs, start, end, wiring))(probs)
# j_batch_node_probs_layer = jit(batch_node_probs_layer, static_argnums=(1,2,3))

def batch_node_probs_init(inputs, num_inputs, num_gates):
  num_instances = inputs.shape[0]
  probs = np.empty((num_gates,num_instances))
  probs = probs.at[0:num_inputs].set(np.transpose(inputs))
  return probs.at[num_inputs:2*num_inputs].set(1-probs[0:num_inputs])
  # probs = index_update(probs, index[0:num_inputs], np.transpose(inputs))
  # return index_update(probs, index[num_inputs:2*num_inputs], 1-probs[0:num_inputs])

j_batch_node_probs_init = jit(batch_node_probs_init, static_argnums=(1,2))

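# Probability of a decision node: for each element, take the product of its
# child probabilities, then sum over all elements.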
def batch_node_probs_dec(probs, id, elements):
  el_probs = [np.prod(probs[np.array(el)], axis=0) for el in elements]
  el_probs = np.stack(el_probs, axis=0)
  return probs.at[id].set(np.sum(el_probs, axis=0))
  # return index_update(probs, index[id], np.sum(el_probs, axis=0))

# random data
random_flattened_images = device_put(random.normal(random.PRNGKey(1), (my_batch_size, num_features)))
random_targets = device_put(one_hot(onp.random.randint(0,10,size=(my_batch_size,)),num_labels))

layer_size = 128
print(f"Jit compiling with layer size {layer_size}")

start_time = time.time()
batch_node_probs(random_flattened_images, big_circuit, layer_size)
print("1st batch took {:0.2f} sec".format(time.time() - start_time))

start_time = time.time()
batch_node_probs(random_flattened_images, big_circuit, layer_size)
print("2nd batch took {:0.2f} sec".format(time.time() - start_time))
print("It would be great to use multi-core for the initial jit compilation of all independent layers, in order to reduce this huge upfront cost.")

layer_size = 10000
print(f"Jit compiling with layer size {layer_size} (this will trigger a bug in the compiler, and use about 60GB of RAM on my server)")

batch_node_probs(random_flattened_images, big_circuit, layer_size)

mattjj avatar Aug 15 '22 19:08 mattjj