LLM Model Conversion
This is the main issue for discussion and information about converting Llama models to .tflite format using ai-edge-torch.
Current Status:
- [x] Llama 3.2 1B: Native Support*, Conversion confirmed working, Tested
- [x] Llama 3.2 3B: Native Support*, Conversion confirmed working, Tested
- [ ] Llama 3.1 8B: No Native Support*, To be Converted, To be Tested
* Native support refers to the model options ai-edge-torch provides by default.
Current conversion script is as described in this comment. Updates to the script will be posted in this issue.
Update
Conversion Memory Requirements
The memory requirements for conversion seem to scale linearly with parameter count: 1B requires x amount of RAM (30 GiB on average), 3B requires 3x (90 GiB), and 8B requires 8x (240 GiB). Since I only have 64 GB, the script kept terminating midway through (although other than the memory limit, no issues were encountered with the conversion script).
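A quick sketch of that linear-scaling assumption (the ~30 GiB figure is my observation for 1B; the 3B and 8B numbers are extrapolations, not measurements):

```python
# Back-of-the-envelope RAM estimate, assuming conversion memory scales
# linearly with parameter count. Anchored on ~30 GiB observed for 1B.
GIB_PER_BILLION_PARAMS = 30  # observed for Llama 3.2 1B; an estimate, not a guarantee

def estimated_ram_gib(billions_of_params: float) -> float:
    """Extrapolate conversion RAM (GiB) from the 1B observation."""
    return GIB_PER_BILLION_PARAMS * billions_of_params

for size in (1, 3, 8):
    print(f"{size}B: ~{estimated_ram_gib(size):.0f} GiB")
```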
AFAIK, there are three ways to get 3B and 8B converted:
- Upgrade my system to include enough RAM (very difficult and expensive)
- Rent a server with enough RAM (costs around $1-2/hr)
- Colab (I don't have much experience with it, and I'm not sure it can run the `.sh` script)
8B support
As stated earlier, 8B is not supported by AI-Edge-Torch's conversion script by default. However, looking at the code for 1B and 3B, the differences appear to be purely config-related, so simply writing (or copying) a conversion config for 8B should let AI-Edge-Torch convert it without issue.
I have not been able to explore this due to the memory limitation.
@freedomtan @Mostelk What do you suggest?
I was able to convert 3B on Colab with 56 GiB (CPU instance + high RAM). Thus I guess you overestimated the memory requirements. My impression is that 64 GiB of DRAM with PyTorch on CPU is enough to convert the 3B one (worst case, you may need to add some swap space).
@anhappdev can we get an EC2 instance to do this?
> I was able to convert 3B on Colab with 56 GiB (CPU instance + high RAM). Thus I guess you overestimated the memory requirements. My impression is that 64 GiB of DRAM with PyTorch on CPU is enough to convert the 3B one (worst case, you may need to add some swap space).
Could be a me thing, but my system kept running out of RAM while trying to convert 3B (I even terminated my window manager and switched to a TTY 😂).
My assumption is linear scaling, but if you're saying 64 GiB is enough for 3B, then 8B would require at minimum around 146 GiB. Personally, I prefer having extra unused RAM to running out or relying on swap.
I was able to run ai-edge-torch's converter for Llama 3.1 8B. The keyword here is run, because I still ran out of memory, but I believe the changed config will produce a TFLite model given enough resources.
Here are the updated scripts
convert.sh
#!/usr/bin/bash
set -euo pipefail
########################################
# Config — edit these as needed
########################################
AI_EDGE_TORCH_VERSION="0.6.0" # <-- customize me
MODEL_REPO="meta-llama/Llama-3.1-8B"
VENV_DIR=".venv"
HF_CACHE_ROOT="${PWD}/.hf_custom_cache" # custom cache location
HF_TOKEN="hf_xxxxxxxxxxxxxxxxxxxx" # <-- placeholder; your Hugging Face access token
NUM_THREADS=30
USE_LOCAL_SCRIPTS=1 # NOTE If this is set, you need to make sure to have a 'convert_llama3_1b.py' file present.
arguments=(
--output_path "."
--model_size 8b
--prefill_seq_lens 8 # repeated flag: each value adds one prefill signature
--prefill_seq_lens 64
--prefill_seq_lens 128
--prefill_seq_lens 256
--prefill_seq_lens 512
--prefill_seq_lens 1024
--prefill_seq_lens 2048
--kv_cache_max_len 3072 # total KV cache length; should exceed the longest prefill length
#--quantize none
)
########################################
# Export threading and Hugging Face related variables
export OMP_NUM_THREADS=${NUM_THREADS}
export MKL_NUM_THREADS=${NUM_THREADS}
export OPENBLAS_NUM_THREADS=${NUM_THREADS}
export NUMEXPR_NUM_THREADS=${NUM_THREADS}
export HF_TOKEN
export HF_HOME="${HF_CACHE_ROOT}"
export HUGGINGFACE_HUB_CACHE="${HF_CACHE_ROOT}/hub"
export TRANSFORMERS_CACHE="${HF_CACHE_ROOT}/transformers" # Deprecated
for d in "${HF_HOME}" "${HUGGINGFACE_HUB_CACHE}" "${TRANSFORMERS_CACHE}"; do
if [[ -d "$d" ]]; then
echo "INFO: Detected existing $d, reusing..."
else
mkdir -p "$d"
fi
done
# Find Python
PY=""
for cand in python3.10 python3 python; do
if command -v "$cand" >/dev/null 2>&1; then PY="$cand"; break; fi
done
if [[ -z "${PY}" ]]; then
echo "ERROR: Python not found." >&2
exit 127
fi
PY_VER="$("$PY" - <<'PY'
import sys
print(".".join(map(str, sys.version_info[:2])))
PY
)"
if [[ "${PY_VER}" != "3.10" ]]; then
echo "WARNING: Detected Python ${PY_VER}; this script expects Python 3.10.x. Proceeding anyway." >&2
fi
if [[ -d "${VENV_DIR}" && -f "${VENV_DIR}/bin/activate" ]]; then
echo "INFO: Detected existing ${VENV_DIR}, reusing..."
else
"$PY" -m venv "${VENV_DIR}"
fi
# shellcheck disable=SC1090
source "${VENV_DIR}/bin/activate"
python -m pip install -U pip setuptools wheel
python -m pip install \
"transformers==4.46.3" \
"accelerate==0.26.0" \
"ai-edge-torch==${AI_EDGE_TORCH_VERSION}"
# Ensure curl exists
if ! command -v curl >/dev/null 2>&1; then
echo "ERROR: 'curl' not found. Please install curl and re-run." >&2
exit 127
fi
# Download the model weights by running the pipeline once
python - <<PY
import torch
from transformers import pipeline
model_id = "$MODEL_REPO"
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)
res = pipe("The key to life is")
PY
if (( USE_LOCAL_SCRIPTS )); then
if [[ -f "convert_llama3_1b.py" ]]; then
echo "Using local ai-edge-torch conversion script."
else
echo "ERROR: local ai-edge-torch conversion script could not be found. Please either have 'convert_llama3_1b.py' in your working directory, or disable 'USE_LOCAL_SCRIPTS'."
exit 2
fi
else
CONVERT_URL="https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/refs/tags/v${AI_EDGE_TORCH_VERSION}/ai_edge_torch/generative/examples/llama/convert_to_tflite.py"
curl -L -o convert_llama3_1b.py "${CONVERT_URL}"
fi
# Resolve the local snapshot path for the model
SNAPSHOT_DIR="$(python - <<PY
from huggingface_hub import snapshot_download
repo = "$MODEL_REPO"
try:
p = snapshot_download(repo, local_files_only=True)
except Exception:
# Fallback: allow a download if not present yet
p = snapshot_download(repo)
print(p)
PY
)"
arguments+=( --checkpoint_path "$SNAPSHOT_DIR" )
echo "Using checkpoint path: ${SNAPSHOT_DIR}"
echo "HF cache root: ${HF_CACHE_ROOT}"
echo "ai-edge-torch version: ${AI_EDGE_TORCH_VERSION}"
echo '========================================================================='
printf '%q ' python convert_llama3_1b.py "${arguments[@]}"; echo
echo '========================================================================='
# Run conversion script
python convert_llama3_1b.py "${arguments[@]}"
convert_llama3_1b.py (ignore the `1b` in the name; it handles all three sizes)
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of converting Llama 3.2 1B model to multi-signature tflite model."""
from absl import app
from llama import build_1b_model, build_3b_model, build_8b_model
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config
from ai_edge_torch.generative.utilities import loader
flags = converter.define_conversion_flags('llama')
_MODEL_SIZE = flags.DEFINE_enum(
'model_size',
'1b',
['1b', '3b', '8b'],
'The size of the model to convert.',
)
_BUILDER = {
'1b': build_1b_model,
'3b': build_3b_model,
'8b': build_8b_model,
}
def main(_):
checkpoint_path = flags.FLAGS.checkpoint_path
pytorch_model = _BUILDER[_MODEL_SIZE.value](
checkpoint_path,
custom_loader=loader.maybe_get_custom_loader(
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
),
mask_cache_size=converter.get_mask_cache_size_from_flags(),
)
converter.convert_to_tflite(
pytorch_model,
output_path=flags.FLAGS.output_path,
output_name_prefix=flags.FLAGS.output_name_prefix,
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
quantize=flags.FLAGS.quantize,
lora_ranks=flags.FLAGS.lora_ranks,
export_config=export_config.get_from_flags(),
)
if __name__ == '__main__':
app.run(main)
llama.py
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of building Llama 3.2 models."""
from functools import partial
import math
from typing import Callable, Dict, Tuple
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import torch
TENSOR_NAMES = model_builder.TENSOR_NAMES
def _build_llama3_rope_cache(
input_pos: torch.Tensor,
n_elem: int,
base: int,
condense_ratio: int,
dtype: torch.dtype,
device: torch.device,
factor: float,
low_freq_factor: float,
high_freq_factor: float,
max_seq_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes Rotary Positional Embeddings for Llama 3.2 model.
It's a modified version of attn_utils.build_rope_cache with additional
arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
and Cos values with scaling factors for quick lookup during the inference.
Reference:
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
Args:
input_pos (torch.Tensor): the given input sequence positions
n_elem (int): Each sequence's dimension.
base (int): Rope base value.
condense_ratio (int): The ratio by which sequence indices are condensed.
dtype (torch.dtype): Output tensor's data type.
device (torch.device): Output tensor's device.
factor (float): Factor to scale theta down for tokens in long range in the
sequence.
low_freq_factor (float): Factor to determine if tokens are in long range
in the sequence.
high_freq_factor (float): Factor to determine if tokens are in short range
in the sequence.
max_seq_len (int): The original token sequence length before extending
ROPE to support longer sequence.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
"""
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
low_freq_wavelen = max_seq_len / low_freq_factor
high_freq_wavelen = max_seq_len / high_freq_factor
wavelen = 2 * math.pi / theta
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
theta = torch.where(wavelen > low_freq_wavelen, theta / factor, theta)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (max_seq_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_theta = (1 - smooth_factor) * theta / factor + smooth_factor * theta
is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
theta = torch.where(is_medium, smoothed_theta, theta)
seq_idx = input_pos / condense_ratio
idx_theta = torch.outer(seq_idx, theta)
cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
return cos, sin
class Llama(model_builder.DecoderOnlyModel):
"""A Llama model built from the Edge Generative API layers.
Llama 3.2 shares the same architecture as TinyLlama except for the RoPE calculation.
"""
pass
def get_1b_model_config() -> cfg.ModelConfig:
"""Returns the model config for a Llama 3.2-1B model."""
attn_config = cfg.AttentionConfig(
num_heads=32,
head_dim=64,
num_query_groups=8,
rotary_base=500000,
rotary_percentage=1.0,
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=8192,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
block_config = cfg.TransformerBlockConfig(
attn_config=attn_config,
ff_config=ff_config,
pre_attention_norm_config=norm_config,
post_attention_norm_config=norm_config,
)
max_seq_len = 8192
# Create the RoPE callable
build_rope = partial(
_build_llama3_rope_cache,
condense_ratio=1,
dtype=torch.float32,
device=torch.device("cpu"),
factor=32.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
max_seq_len=max_seq_len,
)
config = cfg.ModelConfig(
vocab_size=128256,
num_layers=16,
max_seq_len=max_seq_len,
embedding_dim=2048,
block_configs=block_config,
final_norm_config=norm_config,
build_rope=build_rope,
)
return config
def get_3b_model_config() -> cfg.ModelConfig:
"""Returns the model config for a Llama 3.2-3B model."""
config = get_1b_model_config()
# Llama 3.2 has only one block config.
attn_config = config.block_config(0).attn_config
attn_config.num_heads = 24
attn_config.head_dim = 128
config.num_layers = 28
config.embedding_dim = 3072
return config
def get_8b_model_config() -> cfg.ModelConfig:
"""Returns the model config for a Llama 3.1-8B model."""
config = get_1b_model_config()
attn_config = config.block_config(0).attn_config
attn_config.num_heads = 32
attn_config.head_dim = 128
attn_config.num_query_groups = 8
ff_config = config.block_config(0).ff_config
ff_config.intermediate_size = 14336
config.num_layers = 32
config.embedding_dim = 4096
return config
def get_fake_model_config() -> cfg.ModelConfig:
config = get_1b_model_config()
config.vocab_size = 128
config.num_layers = 2
# Llama has only one block config.
config.block_config(0).ff_config.intermediate_size = 64
return config
def _build_model(
checkpoint_path: str,
config: cfg.ModelConfig,
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> torch.nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=config,
tensor_names=TENSOR_NAMES,
model_class=Llama,
custom_loader=custom_loader,
mask_cache_size=mask_cache_size,
)
def build_1b_model(
checkpoint_path: str,
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> torch.nn.Module:
return _build_model(
checkpoint_path, get_1b_model_config(), custom_loader, mask_cache_size
)
def build_3b_model(
checkpoint_path: str,
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> torch.nn.Module:
return _build_model(
checkpoint_path, get_3b_model_config(), custom_loader, mask_cache_size
)
def build_8b_model(
checkpoint_path: str,
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> torch.nn.Module:
return _build_model(
checkpoint_path, get_8b_model_config(), custom_loader, mask_cache_size
)
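For anyone reading `_build_llama3_rope_cache`, the piecewise frequency scaling can be illustrated without torch. This is my plain-Python reading of the same per-frequency rule (the `scale_theta` helper is illustrative, not part of the conversion pipeline); the defaults mirror the 1B config above:

```python
import math

def scale_theta(theta: float, factor: float = 32.0, low_freq_factor: float = 1.0,
                high_freq_factor: float = 4.0, max_seq_len: int = 8192) -> float:
    """Apply the Llama 3 piecewise RoPE scaling to one inverse frequency."""
    low_freq_wavelen = max_seq_len / low_freq_factor    # 8192: "long range" cutoff
    high_freq_wavelen = max_seq_len / high_freq_factor  # 2048: "short range" cutoff
    wavelen = 2 * math.pi / theta
    if wavelen < high_freq_wavelen:
        return theta  # short wavelengths (high frequencies) are left untouched
    if wavelen > low_freq_wavelen:
        return theta / factor  # long wavelengths are scaled down by `factor`
    # medium wavelengths: smooth interpolation between the two regimes
    smooth = (max_seq_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    return (1 - smooth) * theta / factor + smooth * theta

# High-frequency thetas pass through; very low-frequency ones are divided by 32.
print(scale_theta(1.0))                  # unchanged regime
print(scale_theta(2 * math.pi / 16384))  # scaled-down regime
```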
@freedomtan to share the dynamic int8 tflite model later
@Mostelk let's discuss with Scott first (in the group meeting this week, to check available resources).
@freedomtan to check if he can run the script provided by @farook-edev on his personal Colab account.
@farook-edev llama 1b and 3b dynamic range int8 quantized by ai-edge-torch, https://drive.google.com/drive/folders/1ImWzf-Az5L_GrvZ2pZ21fxpJmdsH9oJm?usp=share_link
The dynamic int8 3B model is unlikely to run on production Pixel devices. I witnessed messages indicating “out of memory” on a Pixel 10 Pro, followed by a reboot.
@freedomtan @Mostelk I could rent a server and run the conversion script, it'll cost less than $5 and should produce an 8B .tflite model. LMK if you'd like me to do that.