How do I run the model in a Jupyter Notebook environment?
I'm trying to run the model in a Jupyter Notebook, but I'm not sure how to go about it. Is anyone working on this? I would really appreciate some tips.
(P.S. I'm a TensorFlow developer and am trying to recreate the model architecture using the Keras API. If someone is working on that as well, any help is much appreciated.)
```python
from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5
    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        print("-- Creating embedding")
        self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        print(f"-- Creating transformer blocks ({params.n_layers})")
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        print("-- Adding output layers")
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        print("-- Precomputing frequencies")
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            # print(f"-- Computing layer {layer.layer_id}")
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
```
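If you want to sanity-check the model definition in a notebook cell before loading the real weights, you can instantiate a tiny, randomly initialized `Transformer` and push a batch of dummy token IDs through it. The small dimensions and vocabulary size below are arbitrary values picked just for the shape check; they have nothing to do with the released checkpoints.

```python
# Quick shape check with random weights (NOT the released LLaMA weights).
# dim / n_heads / vocab_size here are arbitrary, chosen only for illustration.
args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=100,
                 max_batch_size=2, max_seq_len=64)
toy_model = Transformer(args)

dummy_tokens = torch.randint(0, args.vocab_size, (2, 8))  # (batch, seq_len)
logits = toy_model(dummy_tokens, start_pos=0)
print(logits.shape)  # expected: torch.Size([2, 100]) — logits for the last position only
```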
```python
from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import os


logger = getLogger()


class Tokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)
```
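As a quick round-trip test in a notebook cell, assuming you have the `tokenizer.model` file from the LLaMA download (the path below is only a placeholder):

```python
# Placeholder path — point this at your downloaded tokenizer.model file.
tok = Tokenizer(model_path="/path/to/tokenizer.model")

ids = tok.encode("Hello, world!", bos=True, eos=False)
print(ids)              # list of ints, starting with the BOS id
print(tok.decode(ids))  # should print the original string (control tokens are dropped)
```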
```python
from typing import List

import torch


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
        print(f"Forwarding {total_len} times")

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # print(f"Feeding tensors forward #{cur_pos}")
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
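If it helps to see nucleus (top-p) sampling in isolation, here is a toy call with made-up probabilities, not part of the original post:

```python
# Toy example: one "distribution" over 5 tokens; the values are invented for illustration.
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])
next_id = sample_top_p(probs, p=0.9)  # restricts sampling to the top-p nucleus, then renormalizes
print(next_id)  # tensor of shape (1, 1) holding the sampled token index
```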
```python
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMA:
    start_time = time.time()
    print("Locating checkpoints")
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert (
        world_size == len(checkpoints)
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    print(f"Found MP={len(checkpoints)} checkpoints")
    ckpt_path = checkpoints[local_rank]

    print("Creating checkpoint instance...")
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    print("Grabbing params...")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    print("Loading model arguments...")
    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=32, **params)

    print("Creating tokenizer...")
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    print("Creating transformer...")
    # torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)

    print("Loading checkpoint to model...", end="")
    _start_time = time.time()
    # torch.set_default_tensor_type(torch.BFloat16Tensor)
    model.load_state_dict(checkpoint, strict=False)
    print(f"done in {time.time() - _start_time:.2f} seconds")

    _start_time = time.time()
    print("Creating LLaMA generator...", end="")
    generator = LLaMA(model, tokenizer)
    print(f"done in {time.time() - _start_time:.2f} seconds")

    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def main(ckpt_dir: str, tokenizer_path: str, temperature: float = 0.8, top_p: float = 0.95):
    generator = load(ckpt_dir, tokenizer_path, 0, 1)
    prompts = [input("Enter prompt: ")]
    print("Starting generation with prompt:", prompts[0])
    while True:
        start_time = time.time()
        results = generator.generate(
            prompts, max_gen_len=30, temperature=temperature, top_p=top_p
        )
        print(f"responded in {time.time() - start_time:.2f} seconds")
        for result in results:
            print(result)
        print("\n==================================\n")
        prompts = [input("Enter next prompt: ")]


if __name__ == "__main__":
    main("checkpoint directory", "tokenizer model path")
```
You can run LLaMA 7B on a single GPU by following the code above. This only works with LLaMA 7B; to run the 13B, 33B, or 65B models, refer to the example.py file in this repository. You must install the torch, fairscale, fire, and sentencepiece libraries.
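In a notebook, something like the following should work, assuming the model, tokenizer, and `load()` cells above have already been run. The pip line just installs the libraries listed above, and the paths are placeholders for wherever your 7B weights and tokenizer live:

```python
# Run once in a notebook cell to install the dependencies mentioned above:
# !pip install torch fairscale fire sentencepiece

# Placeholder paths — point these at your 7B checkpoint folder and tokenizer.model.
generator = load(
    ckpt_dir="/path/to/LLaMA/7B",
    tokenizer_path="/path/to/tokenizer.model",
    local_rank=0,
    world_size=1,  # 7B is distributed as a single .pth shard, so world_size stays 1
)
print(generator.generate(["The capital of France is"], max_gen_len=30)[0])
```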
Thank you for this! I believe this only works with 7B because `load()` requires the world size to match the MP degree (the number of checkpoint shards), and 7B is the only single-shard checkpoint.
I simplified the above changes to a minimal set and committed them to GitHub here for easier access: https://github.com/tbenst/llama/tree/jupyter
Thanks for contributing. Closing this issue.