[Bug]: When I use the following code and change the FlaxGPT2Block to float16, it outputs "I enjoy walking with my cute dog cement Facility Users LankaPHOTOS henceitial periodsGener observations"
Issue Type
Build/Install
Modules Involved
Others
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
spu 0.9.1b0
OS Platform and Distribution
Ubuntu 18.04
Python Version
3.10.14
Compiler Version
GCC 11.2.1
Current Behavior?
When I use the following code and change the FlaxGPT2Block to float16, it outputs "I enjoy walking with my cute dog cement Facility Users LankaPHOTOS henceitial periodsGener observations", which is different from the correct answer. Hope to get your advice, thank you!

```python
import sys
import argparse
from typing import Any, Callable, Dict, Optional, Tuple, Union
from contextlib import contextmanager

import jax
import jax.numpy as jnp
import jax.nn as jnn
import flax.linen as nn
from flax.linen.linear import Array
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

import secretflow as sf
import spu.utils.distributed as ppd
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)
    # greedy decoding: append the argmax token for 10 steps
    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids


copts = spu_pb2.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
copts.enable_optimize_denominator_with_broadcast = True
Array = Any

# In case you have a running secretflow runtime already.
sf.shutdown()


def hack_softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    x = x - x_max
    # exp on large negative is clipped to zero
    b = x > -14
    nexp = jnp.exp(x)
    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)
    return b * (nexp / divisor)


@contextmanager
def hack_softmax_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_softmax = jnn.softmax
    jnn.softmax = hack_softmax
    yield
    # recover back
    jnn.softmax = raw_softmax


def hack_gelu(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    b0 = x < -4.0
    b1 = x < -1.95
    b2 = x > 3.0
    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]
    b4 = b0 ^ b1  # x in [-4, -1.95]

    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]
    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]
    a_coeffs = jnp.array(
        [
            -0.5054031199708174,
            -0.42226581151983866,
            -0.11807612951181953,
            -0.011034134030615728,
        ]
    )
    b_coeffs = jnp.array(
        [
            0.008526321541038084,
            0.5,
            0.3603292692789629,
            0.0,
            -0.037688200365904236,
            0.0,
            0.0018067462606141187,
        ]
    )
    x2 = jnp.square(x)
    x3 = jnp.multiply(x, x2)
    x4 = jnp.square(x2)
    x6 = jnp.square(x3)

    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]
    seg2 = (
        b_coeffs[6] * x6
        + b_coeffs[4] * x4
        + b_coeffs[2] * x2
        + b_coeffs[1] * x
        + b_coeffs[0]
    )

    ret = b2 * x + b4 * seg1 + b3 * seg2

    return ret


@contextmanager
def hack_gelu_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_gelu = jnn.gelu
    jnn.gelu = hack_gelu
    yield
    # recover back
    jnn.gelu = raw_gelu
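
# optional plaintext sanity check: both approximations should closely match
# the jax.nn references inside their fitted ranges (roughly [-4, 4])
_x = jnp.linspace(-4.0, 4.0, 101)
print("max gelu err:", jnp.max(jnp.abs(hack_gelu(_x) - jnn.gelu(_x))))
print("max softmax err:", jnp.max(jnp.abs(hack_softmax(_x) - jnn.softmax(_x))))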

sf.init(['alice', 'bob', 'carol'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['protocol'] = 'ABY3'
conf['runtime_config']['field'] = 'FM64'
conf['runtime_config']['fxp_exp_mode'] = 0
conf['runtime_config']['fxp_exp_iters'] = 5

spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context(
    "hack jax gelu", enabled=True
):
    output_token_ids = spu(text_generation, copts=copts)(
        input_token_ids_, model_params_
    )

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
```
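
The float16 change was along these lines (a minimal sketch, assuming the change amounts to casting the affected parameters with the standard `FlaxPreTrainedModel.to_fp16` helper before they are sent to SPU; the exact edit may differ):

```python
# sketch: cast the pretrained GPT-2 params to float16 before secret-sharing
def get_model_params_fp16():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    # to_fp16 is the standard FlaxPreTrainedModel helper for down-casting params
    return pretrained_model.to_fp16(pretrained_model.params)

model_params = alice(get_model_params_fp16)()
```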
Standalone code to reproduce the issue
print("A bug")
Relevant log output
No response