
[Bug]: When I use the following code and change the FlaxGPT2Block to float16, it outputs "I enjoy walking with my cute dog cement Facility Users LankaPHOTOS henceitial periodsGener observations"

zhangwaer opened this issue on Aug 6, 2024 · 0 comments

Issue Type

Build/Install

Modules Involved

Others

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.1b0

OS Platform and Distribution

Linux 18.04

Python Version

3.10.14

Compiler Version

GCC 11.2.1

Current Behavior?

When I use the following code and change the FlaxGPT2Block to float16, it outputs "I enjoy walking with my cute dog cement Facility Users LankaPHOTOS henceitial periodsGener observations", which differs from the correct answer. Hope to get your advice, thank you!!!

```python
import argparse
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
from flax.linen.linear import Array
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

import secretflow as sf
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
import spu.utils.distributed as ppd

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)
    # Greedy decoding: append the argmax token 10 times.
    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids


copts = spu_pb2.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
copts.enable_optimize_denominator_with_broadcast = True

Array = Any

# In case you have a running secretflow runtime already.
sf.shutdown()


def hack_softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    x = x - x_max
    # exp on large negative is clipped to zero
    b = x > -14
    nexp = jnp.exp(x)
    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)
    return b * (nexp / divisor)


@contextmanager
def hack_softmax_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_softmax = jnn.softmax
    jnn.softmax = hack_softmax
    yield
    # recover back
    jnn.softmax = raw_softmax


def hack_gelu(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    b0 = x < -4.0
    b1 = x < -1.95
    b2 = x > 3.0
    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]
    b4 = b0 ^ b1  # x in [-4, -1.95]

    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]
    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]
    a_coeffs = jnp.array(
        [
            -0.5054031199708174,
            -0.42226581151983866,
            -0.11807612951181953,
            -0.011034134030615728,
        ]
    )
    b_coeffs = jnp.array(
        [
            0.008526321541038084,
            0.5,
            0.3603292692789629,
            0.0,
            -0.037688200365904236,
            0.0,
            0.0018067462606141187,
        ]
    )
    x2 = jnp.square(x)
    x3 = jnp.multiply(x, x2)
    x4 = jnp.square(x2)
    x6 = jnp.square(x3)

    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]
    seg2 = (
        b_coeffs[6] * x6
        + b_coeffs[4] * x4
        + b_coeffs[2] * x2
        + b_coeffs[1] * x
        + b_coeffs[0]
    )

    ret = b2 * x + b4 * seg1 + b3 * seg2
    return ret


@contextmanager
def hack_gelu_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_gelu = jnn.gelu
    jnn.gelu = hack_gelu
    yield
    # recover back
    jnn.gelu = raw_gelu


sf.init(['alice', 'bob', 'carol'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['protocol'] = 'ABY3'
conf['runtime_config']['field'] = 'FM64'
conf['runtime_config']['fxp_exp_mode'] = 0
conf['runtime_config']['fxp_exp_iters'] = 5

spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context(
    "hack jax gelu", enabled=True
):
    output_token_ids = spu(text_generation, copts=copts)(
        input_token_ids_, model_params_
    )

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
```
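The exact float16 edit to FlaxGPT2Block is not shown above. For reference, a minimal sketch of one way such a change could be applied, assuming it amounts to casting the pretrained weights to float16 before they are transferred to SPU (a hypothetical reconstruction, not the edit actually made inside FlaxGPT2Block):

```python
# Hypothetical sketch: cast the pretrained GPT-2 weights to float16. This is
# an assumption about how the float16 change might be expressed; the original
# change was reportedly made inside FlaxGPT2Block itself.
import jax
import jax.numpy as jnp

def cast_params_to_fp16(params):
    # Walk the parameter pytree and cast every floating-point leaf to float16,
    # leaving integer arrays (if any) untouched.
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.float16)
        if jnp.issubdtype(p.dtype, jnp.floating)
        else p,
        params,
    )

# Example use with the script above (hypothetical):
# model_params = alice(lambda: cast_params_to_fp16(get_model_params()))()
```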

Standalone code to reproduce the issue

print("A bug")

Relevant log output

No response
