TinyLlama-1.1B小型但強大的語言模型
TinyLlama 項目旨在在 3 萬億 tokens 上進行預(yù)訓練,構(gòu)建一個擁有11億參數(shù)的Llama模型。經(jīng)過精心優(yōu)化,"僅"需16塊A100-40G的GPU,便可在90天內(nèi)完成這個任務(wù)。訓練已于2023-09-01開始。
該項目采用了與Llama 2完全相同的架構(gòu)和分詞器。這意味著TinyLlama可以在許多基于Llama的開源項目中即插即用。此外,TinyLlama只有1.1B的參數(shù),體積小巧,適用于需要限制計算和內(nèi)存占用的多種應(yīng)用。
發(fā)布時間表
項目團隊會根據(jù)以下計劃逐步發(fā)布中間checkpoint。同時也列了一些基線模型進行比較。
| Date | HF Checkpoint | Tokens | Step | HellaSwag Acc_norm |
|---|---|---|---|---|
| Baseline | StableLM-Alpha-3B | 800B | -- | 38.31 |
| Baseline | Pythia-1B-intermediate-step-50k-105b | 105B | 50k | 42.04 |
| Baseline | Pythia-1B | 300B | 143k | 47.16 |
| 2023-09-04 | TinyLlama-1.1B-intermediate-step-50k-105b | 105B | 50k | 43.50 |
| 2023-09-16 | -- | 500B | -- | -- |
| 2023-10-01 | -- | 1T | -- | -- |
| 2023-10-16 | -- | 1.5T | -- | -- |
| 2023-10-31 | -- | 2T | -- | -- |
| 2023-11-15 | -- | 2.5T | -- | -- |
| 2023-12-01 | -- | 3T | -- | -- |
潛在場景
小型但強大的語言模型對許多應(yīng)用都很有用。以下是一些潛在的場景:
- 幫助對大型模型進行speculative decoding。
- 在邊緣裝置上運行,比如離線的實時機器翻譯 (TinyLlama的4比特量化版本的模型權(quán)重只需要550MB的內(nèi)存)。
- 在游戲中實現(xiàn)實時對話生成(因為還得給游戲本身留顯存所以模型要小)。
此外,項目代碼可以給初學者做一個入門預(yù)訓練的簡潔參考。如果你要訓練50億以下參數(shù)的語言模型, 你其實不需要Megatron-LM。
訓練細節(jié)
以下是訓練設(shè)置的一些細節(jié):
| Setting | Description |
|---|---|
| Parameters | 1.1B |
| Attention Variant | Grouped Query Attention |
| Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632 |
| Sequence Length | 2048 |
| Batch Size | 2 million tokens (2048 * 1024) |
| Learning Rate | 4e-4 |
| Learning Rate Schedule | Cosine with 2000 warmup steps |
| Training Data | Slimpajama & Starcoderdata |
| Data Preprocessing | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata |
| Combined Dataset Size | Around 950B tokens |
| Total Tokens During Training | 3 trillion (slightly more than 3 epochs/143k steps) |
| Natural Language to Code Ratio | 7:3 |
| Hardware | 16 A100-40G GPUs |
速度極快
代碼庫支持以下特性:
- multi-gpu and multi-node distributed training with FSDP.
- flash attention 2.
- fused layernorm.
- fused swiglu.
- fused cross entropy loss .
- fused rotary positional embedding.
有了這些優(yōu)化,可以達到24k tokens/秒/A100的訓練速度,也就是56%的MFU(在A100-80G上的MFU會更高)。這個速度可以讓你可以在8個A100上用32小時訓練一個chinchilla-optimial的模型(11億參數(shù),220億token)。這些優(yōu)化也大大減少了顯存占用,可以把11億參數(shù)的模型塞入40GB的GPU里面還能同時維持16k tokens的per-gpu batch size。只需要把batch size改小一點, 你就可以在RTX 3090/4090上面訓練TinyLlama。
下面是其代碼庫與Pythia和MPT的訓練速度的比較。
| Model | A100 GPU hours taken on 300B tokens |
|---|---|
| TinyLlama-1.1B | 3456 |
| Pythia-1.0B | 4830 |
| MPT-1.3B | 7920 |
Pythia的數(shù)字來自他們的論文。MPT的數(shù)字來自這里,作者說MPT-1.3B"was trained on 440 A100-40GBs for about half a day" on 200B tokens。
TinyLlama是一個相對較小的模型,使用了GQA,這意味著它在推理期間也很快。以下是測量的一些推理速度:
| Framework | Device | Settings | Throughput (tokens/sec) |
|---|---|---|---|
| Llama.cpp | Mac M2 16GB RAM | batch_size=1; 4-bit inference | 71.8 |
| vLLM | A40 GPU | batch_size=100, n=10 | 7094.5 |
