<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          ChatRWKV 學(xué)習(xí)筆記和使用指南

          共 72469字,需瀏覽 145分鐘

           ·

          2023-09-01 02:18

          ec18d60b10e89220e5fab719af8118d6.webp在這里插入圖片描述0x0. 前言

          Receptance Weighted Key Value(RWKV)是pengbo提出的一個(gè)新的語(yǔ)言模型架構(gòu),它使用了線性的注意力機(jī)制,把Transformer的高效并行訓(xùn)練與RNN的高效推理相結(jié)合,使得模型在訓(xùn)練期間可以并行,并在推理的時(shí)候保持恒定的計(jì)算和內(nèi)存復(fù)雜度。目前RWKV的社區(qū)已經(jīng)非常火了,我們從huggingface上可以看到RWKV已經(jīng)訓(xùn)練了多個(gè)百億參數(shù)的模型,特別是RWKV World模型支持世界所有語(yǔ)言的生成+對(duì)話+任務(wù)+代碼,功能十分全面。此外還有很多開(kāi)發(fā)者基于RWKV的微調(diào)模型。

          6d8aae02797c9aeebe470198fca61c6b.webp在部署方面RWKV社區(qū)也取得了長(zhǎng)足的發(fā)展,例如ChatRWKV,rwkv.cpp,RWKV-Runner,rwkv-cpp-cuda,同時(shí)包括mlc-llm,tgi等都支持了RWKV相關(guān)模型,社區(qū)進(jìn)展非常快。本文嘗試從ChatRWKV項(xiàng)目入門學(xué)習(xí)一下RWKV。

          RWKV論文原文:https://arxiv.org/abs/2305.13048

          0x1. 學(xué)習(xí)資料

          這里列出一些我看到的學(xué)習(xí)資料,方便感興趣的讀者學(xué)習(xí):

          • [雙字] 在{Transformer}時(shí)代, {RWKV}是RNN的[文藝復(fù)興]--論文詳解(https://www.bilibili.com/video/BV11N411C76z/?spm_id_from=333.337.search-card.all.click&vd_source=4dffb0fbabed4311f4318e8c6d253a10)
          • 野心勃勃的RNN——RWKV語(yǔ)言模型及其100行代碼極簡(jiǎn)實(shí)現(xiàn)(https://zhuanlan.zhihu.com/p/620469303)
          • github倉(cāng)庫(kù)(https://github.com/BlinkDL/RWKV-LM)
          • rwkv論文原理解讀(https://www.zhihu.com/question/602564718)
          • RWKV的微調(diào)教學(xué),以及RWKV World:支持世界所有語(yǔ)言的生成+對(duì)話+任務(wù)+代碼(https://zhuanlan.zhihu.com/p/638326262)
          • RWKV:用RNN達(dá)到Transformer性能,且支持并行模式和長(zhǎng)程記憶,既快又省顯存,已在14B參數(shù)規(guī)模檢驗(yàn)(https://zhuanlan.zhihu.com/p/599150009)
          • 談?wù)?RWKV 系列的 prompt 設(shè)計(jì),模型選擇,解碼參數(shù)設(shè)置(https://zhuanlan.zhihu.com/p/639629050)
          • RWKV進(jìn)展:一鍵生成論文,純CPU高速INT4,純CUDA脫離pytorch,ctx8192不耗顯存不變慢(https://zhuanlan.zhihu.com/p/626083366)
          • 開(kāi)源1.5/3/7B中文小說(shuō)模型:顯存3G就能跑7B模型,幾行代碼即可調(diào)用(https://zhuanlan.zhihu.com/p/609154637)
          • 發(fā)布幾個(gè)RWKV的Chat模型(包括英文和中文)7B/14B歡迎大家玩(https://zhuanlan.zhihu.com/p/618011122)
          • 實(shí)例:手寫 CUDA 算子,讓 Pytorch 提速 20 倍(某特殊算子)(https://zhuanlan.zhihu.com/p/476297195)
          • BlinkDL/RWKV-World-7B gradio demo(https://huggingface.co/spaces/BlinkDL/RWKV-World-7B/tree/main)
          • ChatRWKV(有可用貓娘模型!)微調(diào)/部署/使用/訓(xùn)練資源合集(https://zhuanlan.zhihu.com/p/616351661)
          • pengbo的專欄(https://www.zhihu.com/people/bopengbopeng/posts)

          原理推薦看第一個(gè)和第二個(gè)鏈接,其它的有選擇觀看,我這里就以ChatRWKV項(xiàng)目的解析為例來(lái)入門RWKV。

          0x2. RWKV in 150 lines

          下面這個(gè)文件 https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py 以150行代碼實(shí)現(xiàn)了RWKV-4-Pile-430M這個(gè)模型,是學(xué)習(xí)RWKV的最佳代碼,所以讓這一節(jié)就是來(lái)逐步解析一下這個(gè)代碼。分析代碼之前先對(duì)RWKV這個(gè)名字的含義和組成RWKV模型2個(gè)關(guān)鍵的元素Time Mixing和Channel Mixing簡(jiǎn)單描述一下,詳細(xì)的原理還是請(qǐng)參考原始論文和第一節(jié)學(xué)習(xí)資料的第一個(gè)視頻鏈接和第四個(gè)原理和公式詳解的文字鏈接。

          0x2.1 RWKV名字含義

          論文中對(duì)名字有說(shuō)明:

          • R: Receptance vector acting as the acceptance of past information. 類似于LSTM的“門控單元”
          • W: Weight is the positional weight decay vector. A trainable model parameter. 可學(xué)習(xí)的位置權(quán)重衰減向量,什么叫“位置權(quán)重衰減”看下面的公式(14)
          • K: Key is a vector analogous to K in traditional attention. 與傳統(tǒng)自注意力機(jī)制
          • V : Value is a vector analogous to V in traditional attention. 與傳統(tǒng)自注意力機(jī)制相同

          0x2.2 RWKV模型架構(gòu)

          RWKV模型由一系列RWKV Block模塊堆疊而成,RWKV Block的結(jié)構(gòu)如下圖所示:a94758f9b6c28769217b4d4167c2664a.webp

          RWKV Block又主要由Time Mixing和Channel Mixing組成。

          Time Mixing模塊的公式定義如下:5ef3f8a339c9ba82a82f315c30c80764.webp

          這里的表示當(dāng)前時(shí)刻,看成當(dāng)前的token,而看成前一個(gè)token,、、的計(jì)算與傳統(tǒng)Attention機(jī)制類似,通過(guò)將當(dāng)前輸入token與前一時(shí)刻輸入token做線性插值,體現(xiàn)了recurrence的特性。然后的計(jì)算則是對(duì)應(yīng)注意力機(jī)制的實(shí)現(xiàn),這個(gè)實(shí)現(xiàn)也是一個(gè)過(guò)去時(shí)刻信息與當(dāng)前時(shí)刻信息的線性插值,注意到這里是指數(shù)形式并且當(dāng)前token和之前的所有token都有一個(gè)指數(shù)衰減求和的關(guān)系,也正是因?yàn)檫@樣讓擁有了線性attention的特性。

          然后RWKV模型里面除了使用Time Mixing建模這種Token間的關(guān)系之外,在Token內(nèi)對(duì)應(yīng)的隱藏層維度上RWKV也進(jìn)行了建模,即通過(guò)Channel Mixing模塊。

          b739a9c8983b88f838e50796f1e6a5fd.webp在這里插入圖片描述

          Channel Mixing的意思就是在特征維度上做融合。假設(shè)特征向量維度是d,那么每一個(gè)維度的元素都要接收其他維度的信息,來(lái)更新它自己。特征向量的每個(gè)維度就是一個(gè)“channel”(通道)。

          下圖展示了RWKV模型整體的結(jié)構(gòu):

          804b1f178143f60f2f1f6dd0393438fa.webp在這里插入圖片描述

          這里提到的token shift就是上面對(duì)r, k, v計(jì)算的時(shí)候類似于卷積滑窗的過(guò)程。然后我們可以看到當(dāng)前的token不僅僅剋呀通過(guò)Time Mixing的token shit和隱藏狀態(tài)States(即)和之前的token建立聯(lián)系,也可以通過(guò)Channel Mixing的token shift和之前的token建立聯(lián)系,類似于擁有了全局感受野

          0x2.3 RWKV_in_150_lines.py 解析

          初始化部分

          首先來(lái)看RWKV模型初始化部分以及最開(kāi)始的一些準(zhǔn)備工作:

                
                ########################################################################################################
          # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
          ########################################################################################################

          # 導(dǎo)入庫(kù)
          import numpy as np
          # 這行代碼設(shè)置了numpy數(shù)組的打印格式,其中precision=4表示小數(shù)點(diǎn)后保留4位,
          # suppress=True表示抑制小數(shù)點(diǎn)的科學(xué)計(jì)數(shù)法表示,linewidth=200表示每行的字符寬度為200。
          np.set_printoptions(precision=4, suppress=True, linewidth=200)
          import types, torch
          from torch.nn import functional as F
          from tokenizers import Tokenizer

          # 加載一個(gè)分詞器
          tokenizer = Tokenizer.from_file("20B_tokenizer.json")

          # 使用types.SimpleNamespace()創(chuàng)建一個(gè)簡(jiǎn)單的命名空間對(duì)象args,并為其設(shè)置以下屬性:
          args = types.SimpleNamespace()
          # 模型的路徑。
          args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
          args.n_layer = 24 # 模型的層數(shù)。
          args.n_embd = 1024 # 模型的嵌入維度。

          # 定義了需要續(xù)寫的字符串,描述了科學(xué)家在西藏的一個(gè)偏遠(yuǎn)山谷中發(fā)現(xiàn)了一群會(huì)說(shuō)流利中文的龍的情況。
          context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
          NUM_TRIALS = 3 # 嘗試生成文本的次數(shù)。
          LENGTH_PER_TRIAL = 100 # 每次嘗試生成的文本長(zhǎng)度。
          TEMPERATURE = 1.0 # 控制生成文本的隨機(jī)性的參數(shù)。值越大,生成的文本越隨機(jī);值越小,生成的文本越確定。
          TOP_P = 0.85 # 在生成文本時(shí),只考慮累積概率超過(guò)此值的詞匯。

          ########################################################################################################

          class RWKV_RNN(torch.jit.ScriptModule):
              def __init__(self, args):
                  super().__init__()
                  # 將傳入的args參數(shù)賦值給類的屬性args。
                  self.args = args
                  # 將模型設(shè)置為評(píng)估模式,這意味著模型中的dropout和batchnorm將被禁用。
                  self.eval() # set torch to inference mode
                  
                  # 從指定路徑加載模型權(quán)重,并確保權(quán)重被加載到CPU上。
                  w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
                  # 這幾行代碼對(duì)加載的權(quán)重進(jìn)行了處理。它們檢查權(quán)重的鍵名,并根據(jù)鍵名對(duì)權(quán)重進(jìn)行不同的操作。
                  for k in w.keys():
                      if      '.time_' in k: w[k] = w[k].squeeze()
                      if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
                      else: w[k] = w[k].float() # convert to f32 type
                  
                  # 創(chuàng)建一個(gè)新的命名空間對(duì)象,并將其賦值給self.w。
                  self.w = types.SimpleNamespace() # set self.w from w
                  # 在self.w中創(chuàng)建一個(gè)名為blocks的字典。
                  self.w.blocks = {}
                  # for k in w.keys(): - 遍歷字典w的所有鍵。注釋中的例子 
                  # "blocks.0.att.time_first" => self.w.blocks[0].att.time_first" 
                  # 說(shuō)明了代碼的目標(biāo):將點(diǎn)分隔的鍵轉(zhuǎn)換為嵌套的屬性訪問(wèn)。
                  for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
                      parts = k.split('.'#  使用.作為分隔符將鍵k分割成多個(gè)部分,并將結(jié)果存儲(chǔ)在parts列表中。
                      last = parts.pop() # 從parts列表中彈出最后一個(gè)元素并存儲(chǔ)在last中。這將是要設(shè)置的屬性的名稱。
                      #  初始化一個(gè)變量here,它將用于遍歷或創(chuàng)建self.w中的嵌套命名空間。
                      here = self.w
                      # 遍歷parts列表中的每個(gè)部分。
                      for p in parts:
                          # 檢查當(dāng)前部分p是否是數(shù)字。
                          if p.isdigit():
                              p = int(p)
                              # 如果當(dāng)前數(shù)字鍵p不在here中,則在here中為其創(chuàng)建一個(gè)新的命名空間。
                              if p not in here: here[p] = types.SimpleNamespace()
                              # 更新here以指向新創(chuàng)建的或已存在的命名空間。
                              here = here[p]
                          # 如果當(dāng)前部分p不是數(shù)字。
                          else:
                              # 如果here沒(méi)有名為p的屬性,則為其創(chuàng)建一個(gè)新的命名空間。
                              if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                              here = getattr(here, p)
                      setattr(here, last, w[k])

          這部分除了準(zhǔn)備一些模型執(zhí)行需要的超參數(shù)之外,還對(duì)RWKV模型進(jìn)行了初始化,值得注意的是在初始化過(guò)程中會(huì)加載RWKV模型的權(quán)重到w這個(gè)字典里面,然后遍歷字典w的所有鍵。注釋中的例子"blocks.0.att.time_first" => self.w.blocks[0].att.time_first" 說(shuō)明了代碼的目標(biāo):將點(diǎn)分隔的鍵轉(zhuǎn)換為嵌套的屬性訪問(wèn)。后面推理的時(shí)候?qū)?huì)直接訪問(wèn)self.w這個(gè)處理之后的權(quán)重對(duì)象。

          RWKV 模型通道融合函數(shù)(Channel mixing)

                
                @torch.jit.script_method
              def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
                  xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
                  xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
                  state[5*i+0] = x
                  r = torch.sigmoid(rw @ xr)
                  k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
                  return r * (vw @ k)

          參考Channel Mixing的公式來(lái)看:

          b739a9c8983b88f838e50796f1e6a5fd.webp在這里插入圖片描述

          在channel_mixing函數(shù)里面,對(duì)應(yīng)當(dāng)前token的詞嵌入向量,表示前一個(gè)token的詞嵌入向量。剩下的變量都是RWKV的可學(xué)習(xí)參數(shù)。然后代碼里面會(huì)動(dòng)態(tài)更新state,讓總是當(dāng)前token的前一個(gè)token的詞嵌入。

          RWKV Time mixing函數(shù)

                
                @torch.jit.script_method
              def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
                  xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
                  xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
                  xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
                  state[5*i+1] = x
                  r = torch.sigmoid(rw @ xr)
                  k = kw @ xk
                  v = vw @ xv
                  
                  aa = state[5*i+2]
                  bb = state[5*i+3]
                  pp = state[5*i+4]
                  ww = time_first + k
                  qq = torch.maximum(pp, ww)
                  e1 = torch.exp(pp - qq)
                  e2 = torch.exp(ww - qq)
                  a = e1 * aa + e2 * v
                  b = e1 * bb + e2
                  wkv = a / b
                  ww = pp + time_decay
                  qq = torch.maximum(ww, k)
                  e1 = torch.exp(ww - qq)
                  e2 = torch.exp(k - qq)
                  state[5*i+2] = e1 * aa + e2 * v
                  state[5*i+3] = e1 * bb + e2
                  state[5*i+4] = qq
                  return ow @ (r * wkv)

          仍然是要對(duì)照公式來(lái)看:

          5ef3f8a339c9ba82a82f315c30c80764.webp然后這里有一個(gè)trick,就是對(duì)的計(jì)算可以寫成RNN的遞歸形式:e92d555f5d489663e044e4f6b85361a2.webp這樣上面的公式就很清晰了,還需要注意的是在實(shí)現(xiàn)的時(shí)候由于有exp的存在,為了保證數(shù)值穩(wěn)定性實(shí)現(xiàn)的時(shí)候減去了每個(gè)公式涉及到的e的指數(shù)部分的Max。

          關(guān)于RWKV 的attention部分()計(jì)算如果你有細(xì)節(jié)不清楚,建議觀看一下這個(gè)視頻:解密RWKV線性注意力的進(jìn)化過(guò)程(https://www.bilibili.com/video/BV1zW4y1D7Qg/?spm_id_from=333.337.search-card.all.click&vd_source=4dffb0fbabed4311f4318e8c6d253a10) 。

          RWKV model forward函數(shù)

                
                # 定義forward方法,它接受兩個(gè)參數(shù):token和state。
          def forward(self, token, state):
                  # 這是一個(gè)上下文管理器,確保在此代碼塊中不會(huì)計(jì)算任何梯度。
                  # 這通常用于評(píng)估模式,以提高性能并避免不必要的計(jì)算。
                  with torch.no_grad():
                      # 如果state為None,則初始化state為一個(gè)全零張量。
                      # 其形狀由self.args.n_layer和self.args.n_embd確定。
                      if state == None:
                          state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                          # 遍歷每一層,并將state的特定位置設(shè)置為-1e30(表示負(fù)無(wú)窮大)。
                          for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
                      # 使用token索引self.w.emb.weight,獲取詞嵌入向量。
                      x = self.w.emb.weight[token]
                      # 對(duì)獲取的詞嵌入向量x應(yīng)用層歸一化。
                      x = self.layer_norm(x, self.w.blocks[0].ln0)
                      
                      for i in range(self.args.n_layer):
                          att = self.w.blocks[i].att # 獲取當(dāng)前層的注意力參數(shù)
                          # 這些行使用time_mixing方法對(duì)x進(jìn)行處理,并將結(jié)果加到x上。
                          x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, 
                              att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, 
                              att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                          ffn = self.w.blocks[i].ffn # 獲取當(dāng)前層的前饋網(wǎng)絡(luò)參數(shù)。
                          # 使用channel_mixing方法對(duì)x進(jìn)行處理,并將結(jié)果加到x上。
                          x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, 
                              ffn.time_mix_k, ffn.time_mix_r, 
                              ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
                      
                      # 對(duì)x應(yīng)用最后的層歸一化,并與self.w.head.weight進(jìn)行矩陣乘法。
                      x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
                      return x.float(), state

          從這里可以看到RWKV的state只和層數(shù)和嵌入層維度有關(guān)系,和序列長(zhǎng)度是無(wú)關(guān)的,這是推理時(shí)相比于RNN的核心優(yōu)點(diǎn)。

          采樣函數(shù)

                
                # 這段代碼是一個(gè)用于生成隨機(jī)樣本的函數(shù)。
          # 這是一個(gè)函數(shù)定義,函數(shù)名為 sample_logits,接受三個(gè)參數(shù) out、temperature 
          # 和 top_p,其中 temperature 默認(rèn)值為 1.0,top_p 默認(rèn)值為 0.8。
          def sample_logits(out, temperature=1.0, top_p=0.8):
              # 這行代碼使用 softmax 函數(shù)對(duì) out 進(jìn)行操作,將輸出轉(zhuǎn)換為概率分布。
              # dim=-1 表示在最后一個(gè)維度上進(jìn)行 softmax 操作。.numpy() 將結(jié)果轉(zhuǎn)換為 NumPy 數(shù)組。
              probs = F.softmax(out, dim=-1).numpy()
              
              # 這行代碼使用 NumPy 的 np.sort 函數(shù)對(duì)概率分布進(jìn)行排序,
              # 并通過(guò) [::-1] 實(shí)現(xiàn)降序排列。結(jié)果保存在 sorted_probs 變量中。
              sorted_probs = np.sort(probs)[::-1]
              # 這行代碼計(jì)算累積概率,使用 NumPy 的 np.cumsum 函數(shù)對(duì) sorted_probs 
              # 進(jìn)行累加操作。結(jié)果保存在 cumulative_probs 變量中。
              cumulative_probs = np.cumsum(sorted_probs)
              # 這行代碼通過(guò)比較 cumulative_probs 是否大于 top_p 來(lái)找到概率分布中的截?cái)帱c(diǎn)。
              # np.argmax 返回第一個(gè)滿足條件的索引,float() 將其轉(zhuǎn)換為浮點(diǎn)數(shù)并保存在 cutoff 變量中。
              cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
              # 這行代碼將低于 cutoff 的概率值設(shè)為 0,即將概率分布中小于截?cái)帱c(diǎn)的概率置零。
              probs[probs < cutoff] = 0
              # 這段代碼根據(jù) temperature 的取值對(duì)概率分布進(jìn)行調(diào)整。
              # 如果 temperature 不等于 1.0,則將概率分布的每個(gè)元素取倒數(shù)的 1.0 / temperature 次冪。
              if temperature != 1.0:
                  probs = probs.pow(1.0 / temperature)
              # 這行代碼將概率分布?xì)w一化,確保所有概率的總和為 1。
              probs = probs / np.sum(probs)
              # 這行代碼使用 np.random.choice 函數(shù)根據(jù)概率分布 probs 生成一個(gè)隨機(jī)樣本,
              # a=len(probs) 表示可選的樣本范圍為 probs 的長(zhǎng)度,p=probs 表示每個(gè)樣本被選中的概率。
              out = np.random.choice(a=len(probs), p=probs)
              # 函數(shù)返回生成的隨機(jī)樣本。
              return out

          這個(gè)函數(shù)的作用是根據(jù)給定的概率分布 out 生成一個(gè)隨機(jī)樣本,通過(guò)調(diào)整溫度 temperature 和頂部概率 top_p 來(lái)控制生成的樣本的多樣性和穩(wěn)定性。

          生成文本的流程

                
                # 打印使用 CPU 加載模型的信息,其中 args.MODEL_NAME 是模型名稱。
          print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
          # 創(chuàng)建一個(gè)名為 model 的 RWKV_RNN 模型實(shí)例,參數(shù)為 args。
          model = RWKV_RNN(args)

          # 打印預(yù)處理上下文信息的提示,提示使用的是較慢的版本。然后初始化 init_state 為 None。
          print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
          init_state = None
          # 對(duì)上下文進(jìn)行分詞編碼,并使用模型的 forward 方法逐個(gè)處理分詞編碼的 tokens,
          # 將結(jié)果保存在 init_out 和 init_state 中。
          for token in tokenizer.encode(context).ids:
              init_out, init_state = model.forward(token, init_state)

          # 使用循環(huán)進(jìn)行多次試驗(yàn)(NUM_TRIALS 次)。
          for TRIAL in range(NUM_TRIALS):
              # 在每次試驗(yàn)的開(kāi)始打印試驗(yàn)信息和上下文。創(chuàng)建一個(gè)空列表 all_tokens 用于保存生成的 tokens。
              print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
              all_tokens = []
              # 初始化變量 out_last 為 0,out 和 state 分別為 init_out 和 init_state 的克隆。
              out_last = 0
              out, state = init_out.clone(), init_state.clone()
              # 在每個(gè)試驗(yàn)中,使用循環(huán)生成 LENGTH_PER_TRIAL 個(gè) tokens。
              for i in range(LENGTH_PER_TRIAL):
                  # 調(diào)用 sample_logits 函數(shù)生成一個(gè)隨機(jī) token,并將其添加到 all_tokens 列表中。
                  token = sample_logits(out, TEMPERATURE, TOP_P)
                  all_tokens += [token]
                  # 使用 tokenizer.decode 將 all_tokens[out_last:] 解碼為文本,
                  # 并檢查解碼結(jié)果是否包含無(wú)效的 utf-8 字符('\ufffd')。如果結(jié)果有效,則將其打印出來(lái)。
                  tmp = tokenizer.decode(all_tokens[out_last:])
                  if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                      print(tmp, end="", flush=True)
                      out_last = i + 1
                  # 調(diào)用模型的 forward 方法,將生成的 token 和當(dāng)前的狀態(tài)傳遞給模型,獲取更新的 out 和 state。
                  out, state = model.forward(token, state)       
          print('\n')

          這段代碼的目的是使用 RWKV_RNN 模型生成文本,模型通過(guò) sample_logits 函數(shù)生成隨機(jī) token,然后將其傳遞給模型進(jìn)行預(yù)測(cè),并根據(jù)預(yù)測(cè)結(jié)果生成下一個(gè) token,不斷重復(fù)這個(gè)過(guò)程直到生成指定數(shù)量的 tokens。生成的文本會(huì)被打印出來(lái)。

          0x3. ChatRWKV v2聊天系統(tǒng)邏輯實(shí)現(xiàn)解析

          ChatRWKV的README中提到v2版本實(shí)現(xiàn)了一些新功能,建議我們使用,v2版本的代碼在 https://github.com/BlinkDL/ChatRWKV/tree/main/v2 。我們這一節(jié)就對(duì)這部分的代碼做一個(gè)解讀。

          chat.py解析

          https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py是ChatRWKV v2的核心實(shí)現(xiàn),我們直接來(lái)看這個(gè)文件。

                
                ########################################################################################################
          # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
          ########################################################################################################

          # os、copy、types、gc、sys:這些是Python標(biāo)準(zhǔn)庫(kù),
          # 用于操作系統(tǒng)功能、對(duì)象復(fù)制、類型管理、垃圾回收和系統(tǒng)特定的參數(shù)和函數(shù)。
          import os, copy, types, gc, sys
          current_path = os.path.dirname(os.path.abspath(__file__))
          sys.path.append(f'{current_path}/../rwkv_pip_package/src')

          import numpy as np
          # prompt_toolkit中的prompt:用于構(gòu)建命令行界面的庫(kù)。
          from prompt_toolkit import prompt
          # 腳本嘗試根據(jù)傳遞給腳本的命令行參數(shù)設(shè)置CUDA_VISIBLE_DEVICES環(huán)境變量。
          try:
              os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
          except:
              pass
          # 調(diào)用np.set_printoptions()函數(shù)設(shè)置NumPy數(shù)組的打印選項(xiàng)。它配置精度、抑制小值和輸出行的最大寬度。
          np.set_printoptions(precision=4, suppress=True, linewidth=200)
          args = types.SimpleNamespace()

          # 腳本打印有關(guān)ChatRWKV版本的信息。
          print('\n\nChatRWKV v2 https://github.com/BlinkDL/ChatRWKV')

          import torch
          # 針對(duì)PyTorch設(shè)置了幾個(gè)配置,以優(yōu)化其在GPU上的性能。
          torch.backends.cudnn.benchmark = True
          torch.backends.cudnn.allow_tf32 = True
          torch.backends.cuda.matmul.allow_tf32 = True

          # 有一些被注釋的行代表不同的設(shè)置,可以嘗試優(yōu)化性能。這些設(shè)置與PyTorch的即時(shí)編譯(JIT)和融合有關(guān)。
          # Tune these below (test True/False for all of them) to find the fastest setting:
          # torch._C._jit_set_profiling_executor(True)
          # torch._C._jit_set_profiling_mode(True)
          # torch._C._jit_override_can_fuse_on_cpu(True)
          # torch._C._jit_override_can_fuse_on_gpu(True)
          # torch._C._jit_set_texpr_fuser_enabled(False)
          # torch._C._jit_set_nvfuser_enabled(False)

          ########################################################################################################
          #
          # 有一些注釋解釋了不同的模型精度選項(xiàng)(fp16、fp32、bf16、xxxi8)及其影響。
          # fp16 = good for GPU (!!! DOES NOT support CPU !!!)
          # fp32 = good for CPU
          # bf16 = less accuracy, supports some CPUs
          # xxxi8 (example: fp16i8) = xxx with int8 quantization to save 50% VRAM/RAM, slightly less accuracy
          #
          # Read https://pypi.org/project/rwkv/ for Strategy Guide
          #
          ########################################################################################################

          這段代碼設(shè)置了必要的依賴項(xiàng),配置了GPU使用環(huán)境,并提供了有關(guān)RWKV語(yǔ)言模型及其精度選項(xiàng)的信息。

                
                # 這個(gè)變量用于設(shè)置模型推理的策略。在代碼中有幾個(gè)不同的策略選項(xiàng)被注釋掉了,
          # 而實(shí)際選擇的策略是 'cuda fp16',表示使用CUDA加速并使用半精度浮點(diǎn)數(shù)進(jìn)行計(jì)算。
          # args.strategy = 'cpu fp32'
          args.strategy = 'cuda fp16'
          # args.strategy = 'cuda:0 fp16 -> cuda:1 fp16'
          # args.strategy = 'cuda fp16i8 *10 -> cuda fp16'
          # args.strategy = 'cuda fp16i8'
          # args.strategy = 'cuda fp16i8 -> cpu fp32 *10'
          # args.strategy = 'cuda fp16i8 *10+'

          # 這兩個(gè)變量設(shè)置了環(huán)境變量。RWKV_JIT_ON 控制是否啟用即時(shí)編譯(JIT),
          # RWKV_CUDA_ON 控制是否編譯CUDA內(nèi)核。在代碼中,RWKV_CUDA_ON 被設(shè)置為0,即不編譯CUDA內(nèi)核。
          os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
          os.environ["RWKV_CUDA_ON"] = '0' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

          # 這個(gè)變量設(shè)置了聊天系統(tǒng)使用的語(yǔ)言,可以選擇英語(yǔ)('English')、中文('Chinese')或日語(yǔ)('Japanese')。
          CHAT_LANG = 'English' # English // Chinese // more to come

          # 這個(gè)變量設(shè)置了模型的名稱和路徑。根據(jù)選擇的語(yǔ)言不同,會(huì)有不同的模型名稱被設(shè)置。
          # 模型名稱指定了模型文件的位置,以及其他一些相關(guān)信息。
          # Download RWKV models from https://huggingface.co/BlinkDL
          # Use '/' in model path, instead of '\'
          # Use convert_model.py to convert a model for a strategy, for faster loading & saves CPU RAM 
          if CHAT_LANG == 'English':
              args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-raven/RWKV-4-Raven-14B-v12-Eng98%-Other2%-20230523-ctx8192'
              # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-raven/RWKV-4-Raven-7B-v12-Eng98%-Other2%-20230521-ctx8192'
              # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230313-ctx8192-test1050'

          elif CHAT_LANG == 'Chinese'# Raven系列可以對(duì)話和 +i 問(wèn)答。Novel系列是小說(shuō)模型,請(qǐng)只用 +gen 指令續(xù)寫。
              args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-raven/RWKV-4-Raven-7B-v12-Eng49%-Chn49%-Jpn1%-Other1%-20230530-ctx8192'
              # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-world/RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096'
              # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-novel/RWKV-4-Novel-7B-v1-ChnEng-20230426-ctx8192'

          elif CHAT_LANG == 'Japanese':
              # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-raven/RWKV-4-Raven-14B-v8-EngAndMore-20230408-ctx4096'
              args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-raven/RWKV-4-Raven-7B-v10-Eng89%-Jpn10%-Other1%-20230420-ctx4096'

          # -1.py for [User & Bot] (Q&A) prompt
          # -2.py for [Bob & Alice] (chat) prompt
          # 這個(gè)變量設(shè)置了模型的名稱和路徑。根據(jù)選擇的語(yǔ)言不同,會(huì)有不同的模型名稱被設(shè)置。
          # 模型名稱指定了模型文件的位置,以及其他一些相關(guān)信息。
          PROMPT_FILE = f'{current_path}/prompt/default/{CHAT_LANG}-2.py'

          # 代碼中還包含了一些其他參數(shù)的設(shè)置,如聊天長(zhǎng)度的限制、生成文本的參數(shù)(溫度、top-p值等)、重復(fù)懲罰等。
          CHAT_LEN_SHORT = 40
          CHAT_LEN_LONG = 150
          FREE_GEN_LEN = 256

          # For better chat & QA quality: reduce temp, reduce top-p, increase repetition penalties
          # Explanation: https://platform.openai.com/docs/api-reference/parameter-details
          # 這個(gè)變量用于控制生成文本的溫度。通過(guò)調(diào)整溫度值,可以控制生成文本的隨機(jī)性和多樣性。
          # 在代碼中設(shè)置為1.2,表示較高的溫度,可以增加生成文本的多樣性。
          GEN_TEMP = 1.2 # It could be a good idea to increase temp when top_p is low
          # 這個(gè)變量用于控制生成文本的top-p值。Top-p是一種生成文本的策略,
          # 它限制了生成文本中概率最高的單詞的累積概率。
          # 通過(guò)減小top-p值,可以提高生成文本的準(zhǔn)確性和一致性。在代碼中設(shè)置為0.5,表示較低的top-p值。
          GEN_TOP_P = 0.5 # Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)
          # 這兩個(gè)變量分別控制生成文本中重復(fù)內(nèi)容的懲罰權(quán)重。
          # GEN_alpha_presence 控制了存在性懲罰的權(quán)重,即生成文本中重復(fù)內(nèi)容的懲罰程度。
          # GEN_alpha_frequency 控制了頻率懲罰的權(quán)重,即生成文本中連續(xù)重復(fù)內(nèi)容的懲罰程度。
          # 在代碼中,它們都被設(shè)置為0.4。
          GEN_alpha_presence = 0.4 # Presence Penalty
          GEN_alpha_frequency = 0.4 # Frequency Penalty
          # 這個(gè)變量控制了重復(fù)懲罰的衰減率。通過(guò)減小衰減率,可以使重復(fù)懲罰在生成文本的后續(xù)部分逐漸減弱。
          # 在代碼中設(shè)置為0.996。
          GEN_penalty_decay = 0.996
          # 這個(gè)變量設(shè)置了一些標(biāo)點(diǎn)符號(hào),用于表示要避免在生成文本中重復(fù)的內(nèi)容。
          # 在代碼中,它包含了中文的逗號(hào)、冒號(hào)、問(wèn)號(hào)和感嘆號(hào)。
          AVOID_REPEAT = ',:?!'

          # 這個(gè)變量用于將輸入分成多個(gè)塊,以節(jié)省顯存(VRAM)。在代碼中設(shè)置為256。
          CHUNK_LEN = 256 # split input into chunks to save VRAM (shorter -> slower)

          # 這個(gè)變量包含了模型的名稱和路徑。根據(jù)代碼中的注釋,可以看到有幾個(gè)不同的模型路徑被設(shè)置。
          # 根據(jù)具體情況,會(huì)選擇其中一個(gè)模型路徑。
          # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-world/RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096'
          # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-world/RWKV-4-World-CHNtuned-0.4B-v1-20230618-ctx4096'
          # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-world/RWKV-4-World-3B-v1-20230619-ctx4096'

          if args.MODEL_NAME.endswith('/'): # for my own usage
              if 'rwkv-final.pth' in os.listdir(args.MODEL_NAME):
                  args.MODEL_NAME = args.MODEL_NAME + 'rwkv-final.pth'
              else:
                  latest_file = sorted([x for x in os.listdir(args.MODEL_NAME) if x.endswith('.pth')], key=lambda x: os.path.getctime(os.path.join(args.MODEL_NAME, x)))[-1]
                  args.MODEL_NAME = args.MODEL_NAME + latest_file

          這段代碼相當(dāng)于配置文件,來(lái)設(shè)置RWKV聊天系統(tǒng)的運(yùn)行參數(shù)和模型信息。根據(jù)選擇的語(yǔ)言不同,會(huì)加載相應(yīng)的模型,并設(shè)置相應(yīng)的參數(shù)。

                
                ########################################################################################################

          # 這行代碼用于打印一個(gè)字符串,包含了 CHAT_LANG、args.strategy 和 PROMPT_FILE 的值。
          # 它會(huì)在控制臺(tái)輸出當(dāng)前的語(yǔ)言、策略和提示文件的信息。
          print(f'\n{CHAT_LANG} - {args.strategy} - {PROMPT_FILE}')
          # 導(dǎo)入 RWKV 模型
          from rwkv.model import RWKV
          # 這行代碼導(dǎo)入了 PIPELINE 工具,用于處理模型的輸入和輸出。
          from rwkv.utils import PIPELINE

          # 用于加載提示文件的內(nèi)容并返回相應(yīng)的變量。
          def load_prompt(PROMPT_FILE):
              # 該函數(shù)首先創(chuàng)建了一個(gè)空的字典 variables,然后使用 open 函數(shù)打開(kāi) PROMPT_FILE 文件。
              variables = {}
              # 下來(lái),使用 exec 函數(shù)將文件內(nèi)容編譯并執(zhí)行,將結(jié)果存儲(chǔ)在 variables 字典中。
              with open(PROMPT_FILE, 'rb'as file:
                  exec(compile(file.read(), PROMPT_FILE, 'exec'), variables)
              # 然后,從 variables 字典中獲取了 user、bot、interface 和 init_prompt 的值。
              # init_prompt 被處理為一個(gè)列表,去除了多余的空格和換行符,并在開(kāi)頭和結(jié)尾添加了換行符。
              # 最后,函數(shù)返回獲取的變量。
              user, bot, interface, init_prompt = variables['user'], variables['bot'], variables['interface'], variables['init_prompt']
              init_prompt = init_prompt.strip().split('\n')
              for c in range(len(init_prompt)):
                  init_prompt[c] = init_prompt[c].strip().strip('\u3000').strip('\r')
              init_prompt = '\n' + ('\n'.join(init_prompt)).strip() + '\n\n'
              return user, bot, interface, init_prompt

          # Load Model

          # 這行代碼用于打印一個(gè)字符串,指示正在加載模型,并輸出 args.MODEL_NAME 的值。
          print(f'Loading model - {args.MODEL_NAME}')
          # 這行代碼創(chuàng)建了一個(gè) RWKV 模型對(duì)象,使用指定的模型路徑 args.MODEL_NAME 和策略 args.strategy。
          model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
          # 根據(jù) args.MODEL_NAME 的值,選擇了不同的分詞器和特殊標(biāo)記。
          # 如果模型路徑中包含 'world/' 或 '-World-',則使用 "rwkv_vocab_v20230424" 的分詞器,
          # 并設(shè)置了 END_OF_TEXT 和 END_OF_LINE 的值。
          # 否則,使用 current_path(當(dāng)前路徑)和 "20B_tokenizer.json" 的分詞器,
          # 并設(shè)置了 END_OF_TEXT、END_OF_LINE 和 END_OF_LINE_DOUBLE 的值。
          if 'world/' in args.MODEL_NAME or '-World-' in args.MODEL_NAME:
              pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
              END_OF_TEXT = 0
              END_OF_LINE = 11
          else:
              pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
              END_OF_TEXT = 0
              END_OF_LINE = 187
              END_OF_LINE_DOUBLE = 535
          # pipeline = PIPELINE(model, "cl100k_base")
          # END_OF_TEXT = 100257
          # END_OF_LINE = 198

          # 這行代碼創(chuàng)建了一個(gè)空的列表 model_tokens,用于存儲(chǔ)模型的輸入token。
          model_tokens = []
          model_state = None

          # 這行代碼創(chuàng)建了一個(gè)空的列表 AVOID_REPEAT_TOKENS,用于存儲(chǔ)避免重復(fù)的標(biāo)記。
          AVOID_REPEAT_TOKENS = []
          for i in AVOID_REPEAT:
              # 在循環(huán)內(nèi)部,將當(dāng)前元素 i 使用 pipeline.encode 函數(shù)進(jìn)行編碼,
              # 并將結(jié)果添加到 AVOID_REPEAT_TOKENS 列表中。
              # 這樣,AVOID_REPEAT_TOKENS 列表中存儲(chǔ)了避免重復(fù)的token的編碼。
              dd = pipeline.encode(i)
              assert len(dd) == 1
              AVOID_REPEAT_TOKENS += dd

          這段代碼的主要功能是加載和配置RWKV模型,并準(zhǔn)備生成文本所需的參數(shù)和工具。繼續(xù)解析:

                
                ########################################################################################################

          # 這是一個(gè)函數(shù)定義,用于以RNN模式運(yùn)行RWKV模型生成文本。
          def run_rnn(tokens, newline_adj = 0):
              # 這行代碼聲明在函數(shù)內(nèi)部使用全局變量 model_tokens 和 model_state。
              global model_tokens, model_state

              # 將輸入的token列表轉(zhuǎn)換為整數(shù)類型。
              tokens = [int(x) for x in tokens]
              # 將輸入的標(biāo)記列表添加到全局變量 model_tokens 中。
              model_tokens += tokens
              # print(f'### model ###\n{tokens}\n[{pipeline.decode(model_tokens)}]')

              # 當(dāng)token列表的長(zhǎng)度大于0時(shí),執(zhí)行以下操作:
              while len(tokens) > 0:
                  # 使用模型的前向傳播函數(shù) model.forward 對(duì)token列表的前
                  # CHUNK_LEN 個(gè)token進(jìn)行推理,并更新模型狀態(tài)。
                  out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
                  # 將token列表更新為剩余的token。
                  tokens = tokens[CHUNK_LEN:]

              # 將輸出概率向量中的換行符標(biāo)記 END_OF_LINE 的概率增加 newline_adj,用于調(diào)整換行的概率。
              out[END_OF_LINE] += newline_adj # adjust \n probability

              # 如果模型最后一個(gè)標(biāo)記在避免重復(fù)標(biāo)記列表 AVOID_REPEAT_TOKENS 中,執(zhí)行以下操作:
              if model_tokens[-1in AVOID_REPEAT_TOKENS:
                  # 將輸出概率向量中模型最后一個(gè)標(biāo)記對(duì)應(yīng)的概率設(shè)置為一個(gè)極小的值,用于避免模型生成重復(fù)的標(biāo)記。
                  out[model_tokens[-1]] = -999999999
              return out

          all_state = {}
          # 這是一個(gè)函數(shù)定義,用于保存模型狀態(tài)和token列表。
          def save_all_stat(srv, name, last_out):
              # 創(chuàng)建保存狀態(tài)的鍵名。
              n = f'{name}_{srv}'
              all_state[n] = {} # 創(chuàng)建一個(gè)空的字典,用于保存模型狀態(tài)和token列表。
              all_state[n]['out'] = last_out # 將最后的輸出概率向量保存到字典中。
              all_state[n]['rnn'] = copy.deepcopy(model_state) # 將深拷貝后的模型狀態(tài)保存到字典中。
              all_state[n]['token'] = copy.deepcopy(model_tokens) # 將深拷貝后的token列表保存到字典中。

          # 這是一個(gè)函數(shù)定義,用于加載保存的模型狀態(tài)和標(biāo)記列表。
          def load_all_stat(srv, name):
              # 這行代碼聲明在函數(shù)內(nèi)部使用全局變量 model_tokens 和 model_state。
              global model_tokens, model_state
              # 獲取保存狀態(tài)的鍵名。
              n = f'{name}_{srv}'
              # 將保存的模型狀態(tài)深拷貝給全局變量 model_state。
              model_state = copy.deepcopy(all_state[n]['rnn'])
              # 將保存的token列表深拷貝給全局變量 model_tokens。
              model_tokens = copy.deepcopy(all_state[n]['token'])
              return all_state[n]['out']

          # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
          # 這是一個(gè)函數(shù)定義,用于修復(fù)token列表。
          def fix_tokens(tokens):
              # 根據(jù)模型路徑是否包含 'world/' 或 '-World-',執(zhí)行以下操作:
              if 'world/' in args.MODEL_NAME or '-World-' in args.MODEL_NAME:
                  # 如果是,則返回原始的標(biāo)記列表 tokens,無(wú)需修復(fù)。
                  return tokens
              # 如果不是,則檢查標(biāo)記列表的最后一個(gè)標(biāo)記是否為 END_OF_LINE_DOUBLE,
              # 如果是,則將標(biāo)記列表中的最后一個(gè)標(biāo)記替換為 END_OF_LINE 重復(fù)兩次。
              if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
                  tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
              return tokens

          ########################################################################################################

          這段代碼定義了以RNN模式運(yùn)行RWKV的函數(shù)以及保存和恢復(fù)模型狀態(tài)的函數(shù),最后還定義了一個(gè)修復(fù)tokens的工具函數(shù)用于對(duì)world模型和其它模型分別進(jìn)行處理,繼續(xù)解析:

                
                ########################################################################################################

          # Run inference
          print(f'\nRun prompt...'# 打印提示信息 "Run prompt..."。

          # 調(diào)用 load_prompt 函數(shù)加載提示文件,將返回的用戶、機(jī)器人、界面和初始提示內(nèi)容
          # 分別賦值給變量 user、bot、interface 和 init_prompt。
          user, bot, interface, init_prompt = load_prompt(PROMPT_FILE)
          # 調(diào)用 fix_tokens 函數(shù)修復(fù)初始提示內(nèi)容的標(biāo)記列表,并將修復(fù)后的標(biāo)記列表傳遞給 run_rnn 函數(shù)進(jìn)行推理。
          # 將生成的輸出概率向量賦值給變量 out。
          out = run_rnn(fix_tokens(pipeline.encode(init_prompt)))
          # 調(diào)用 save_all_stat 函數(shù)保存模型狀態(tài)和標(biāo)記列表,鍵名為 'chat_init',值為 out。
          save_all_stat('''chat_init', out)
          # 執(zhí)行垃圾回收和清空GPU緩存的操作。
          gc.collect()
          torch.cuda.empty_cache()

          # 創(chuàng)建一個(gè)服務(wù)器列表 srv_list,其中包含一個(gè)名為 'dummy_server' 的服務(wù)器。
          srv_list = ['dummy_server']
          # 遍歷服務(wù)器列表,對(duì)于每個(gè)服務(wù)器,調(diào)用 save_all_stat 函數(shù)保存模型狀態(tài)和token列表,
          # 鍵名包含服務(wù)器名和 'chat',值為 out。
          for s in srv_list:
              save_all_stat(s, 'chat', out)

          # 定義一個(gè)名為 reply_msg 的函數(shù),該函數(shù)接受一個(gè)參數(shù) msg,并打印機(jī)器人、界面和回復(fù)消息。
          def reply_msg(msg):
              print(f'{bot}{interface} {msg}\n')

          # 這段代碼定義了一個(gè)名為 on_message 的函數(shù),用于處理接收到的消息。
          def on_message(message):
              # 聲明在函數(shù)內(nèi)部使用全局變量 model_tokens、model_state、user、bot、interface 和 init_prompt。
              global model_tokens, model_state, user, bot, interface, init_prompt

              # 將字符串 'dummy_server' 賦值給變量 srv。
              srv = 'dummy_server'

              # 將接收到的消息中的轉(zhuǎn)義字符 '\\n' 替換為換行符 '\n',并去除首尾空白字符,將結(jié)果賦值給變量 msg。
              msg = message.replace('\\n','\n').strip()

              x_temp = GEN_TEMP
              x_top_p = GEN_TOP_P
              # 如果消息中包含 -temp=,執(zhí)行以下操作:
              # a. 從消息中提取 -temp= 后面的值,并將其轉(zhuǎn)換為浮點(diǎn)數(shù)類型賦值給 x_temp。
              # 從消息中移除 -temp= 部分。
              if ("-temp=" in msg):
                  x_temp = float(msg.split("-temp=")[1].split(" ")[0])
                  msg = msg.replace("-temp="+f'{x_temp:g}'"")
                  # print(f"temp: {x_temp}")
              if ("-top_p=" in msg):
                  x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
                  msg = msg.replace("-top_p="+f'{x_top_p:g}'"")
                  # print(f"top_p: {x_top_p}")
              # 如果 x_temp 小于等于 0.2,將其設(shè)置為 0.2。
              if x_temp <= 0.2:
                  x_temp = 0.2
              # 如果 x_temp 大于等于 5,將其設(shè)置為 5。
              if x_temp >= 5:
                  x_temp = 5
              # 如果 x_top_p 小于等于 0,將其設(shè)置為 0。
              if x_top_p <= 0:
                  x_top_p = 0
              # 去除消息首尾空白字符。
              msg = msg.strip()
              
              # 如果消息等于 '+reset':
              if msg == '+reset':
                  # 加載保存的初始模型狀態(tài)并將其保存為 out。
                  out = load_all_stat('''chat_init')
                  # 保存模型狀態(tài)和token列表。
                  save_all_stat(srv, 'chat', out)
                  # 調(diào)用 reply_msg 函數(shù)打印回復(fù)消息 "Chat reset."。
                  reply_msg("Chat reset.")
                  return
              
              # use '+prompt {path}' to load a new prompt
              # 如果消息以 '+prompt ' 開(kāi)頭
              elif msg[:8].lower() == '+prompt ':
                  # 打印 "Loading prompt..."。
                  print("Loading prompt...")
                  try:
                      # 提取消息中 'prompt ' 后面的內(nèi)容作為提示文件路徑,并將其賦值給變量 PROMPT_FILE。
                      PROMPT_FILE = msg[8:].strip()
                      # 加載提示文件,將返回的用戶、機(jī)器人、界面和初始提示內(nèi)容分別賦值給變量
                      # user、bot、interface 和 init_prompt。
                      user, bot, interface, init_prompt = load_prompt(PROMPT_FILE)
                      # 對(duì)prompt編碼和推理
                      out = run_rnn(fix_tokens(pipeline.encode(init_prompt)))
                      # 保存模型狀態(tài)和token列表。
                      save_all_stat(srv, 'chat', out)
                      # 打印 "Prompt set up."。
                      print("Prompt set up.")
                      gc.collect()
                      torch.cuda.empty_cache()
                  except:
                      # 捕獲異常,打印 "Path error."。
                      print("Path error.")

              # 如果消息以 '+gen '、'+i '、'+qa '、'+qq '、'+++' 或 '++' 開(kāi)頭,執(zhí)行以下操作:
              elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

                  if msg[:5].lower() == '+gen ':
                      # 提取消息中 'gen ' 后面的內(nèi)容作為新的提示內(nèi)容,并將其賦值給變量 new。
                      new = '\n' + msg[5:].strip()
                      # print(f'### prompt ###\n[{new}]')
                      # 將模型狀態(tài)和標(biāo)記列表重置為空。
                      model_state = None
                      model_tokens = []
                      # 運(yùn)行RNN模型進(jìn)行推理,生成回復(fù)的輸出概率向量,并將其保存為 out。
                      out = run_rnn(pipeline.encode(new))
                      # 保存模型狀態(tài)和token列表。
                      save_all_stat(srv, 'gen_0', out)

                  elif msg[:3].lower() == '+i ':
                      # 提取消息中 'i ' 后面的內(nèi)容作為新的指令,并將其賦值給變量 msg。
                      msg = msg[3:].strip().replace('\r\n','\n').replace('\n\n','\n')
                      # 替換指令中的換行符,將 '\r\n' 替換為 '\n',將連續(xù)的兩個(gè)換行符 '\n\n' 替換為單個(gè)換行符 '\n'。
                      # 構(gòu)建新的提示內(nèi)容 new,包括指令和響應(yīng)模板。
                      new = f'''
          Below is an instruction that describes a task. Write a response that appropriately completes the request.

          # Instruction:
          {msg}

          # Response:
          '''

                      # print(f'### prompt ###\n[{new}]')
                      # 將模型狀態(tài)和token列表重置為空。
                      model_state = None
                      model_tokens = []
                      # 運(yùn)行RNN模型進(jìn)行推理,生成回復(fù)的輸出概率向量,并將其保存為 out。
                      out = run_rnn(pipeline.encode(new))
                      # 保存模型狀態(tài)和token列表。
                      save_all_stat(srv, 'gen_0', out)

                  elif msg[:4].lower() == '+qq ':
                      # 提取消息中 'qq ' 后面的內(nèi)容作為新的問(wèn)題,構(gòu)建新的提示內(nèi)容 new,包括問(wèn)題和回答模板。
                      new = '\nQ: ' + msg[4:].strip() + '\nA:'
                      # print(f'### prompt ###\n[{new}]')
                      # 將模型狀態(tài)和token列表重置為空。
                      model_state = None
                      model_tokens = []
                      # 運(yùn)行RNN模型進(jìn)行推理,生成回復(fù)的輸出概率向量,并將其保存為 out。
                      out = run_rnn(pipeline.encode(new))
                      # 保存模型狀態(tài)和token列表。
                      save_all_stat(srv, 'gen_0', out)

                  elif msg[:4].lower() == '+qa ':
                      # 加載保存的初始模型狀態(tài)并將其保存為 out。
                      out = load_all_stat('''chat_init')

                      # 提取消息中 'qa ' 后面的內(nèi)容作為真實(shí)消息,并將其賦值給變量 real_msg。
                      real_msg = msg[4:].strip()
                      # 構(gòu)建新的提示內(nèi)容 new,包括用戶、界面、真實(shí)消息和機(jī)器人。
                      new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
                      # print(f'### qa ###\n[{new}]')
                      
                      # 運(yùn)行RNN模型進(jìn)行推理,生成回復(fù)的輸出概率向量,并將其保存為 out。
                      out = run_rnn(pipeline.encode(new))
                      # 保存模型狀態(tài)和token列表。
                      save_all_stat(srv, 'gen_0', out)

                  elif msg.lower() == '+++':
                      try:
                          # 加載保存的模型狀態(tài) gen_1 并將其保存為 out。
                          out = load_all_stat(srv, 'gen_1')
                          # 保存模型狀態(tài)和token列表。
                          save_all_stat(srv, 'gen_0', out)
                      except:
                          return

                  elif msg.lower() == '++':
                      try:
                          # 加載保存的模型狀態(tài) gen_0 并將其保存為 out。
                          out = load_all_stat(srv, 'gen_0')
                      except:
                          return

                  # 將變量 begin 設(shè)置為 model_tokens 的長(zhǎng)度。
                  begin = len(model_tokens)
                  # 將變量 out_last 設(shè)置為 begin。
                  out_last = begin
                  # 創(chuàng)建一個(gè)空字典 occurrence。
                  occurrence = {}
                  # 循環(huán) FREE_GEN_LEN+100 次,其中 FREE_GEN_LEN 是一個(gè)常量,代表自由生成的長(zhǎng)度。
                  for i in range(FREE_GEN_LEN+100):
                      # 遍歷字典 occurrence 中的鍵。
                      for n in occurrence:
                          # 將 out[n] 減去一個(gè)計(jì)算得到的重復(fù)懲罰項(xiàng)。
                          out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
                      # 使用 pipeline.sample_logits 函數(shù)根據(jù)概率向量 out 生成一個(gè)標(biāo)記,并將其賦值給 token。
                      token = pipeline.sample_logits(
                          out,
                          temperature=x_temp,
                          top_p=x_top_p,
                      )
                      # 如果 token 等于 END_OF_TEXT,跳出內(nèi)層循環(huán)。
                      if token == END_OF_TEXT:
                          break
                      # 遍歷字典 occurrence 中的鍵。將 occurrence[xxx] 乘以一個(gè)常量 GEN_penalty_decay
                      for xxx in occurrence:
                          occurrence[xxx] *= GEN_penalty_decay
                      # 如果 token 不在 occurrence 中,將 occurrence[token] 設(shè)置為 1。
                      if token not in occurrence:
                          occurrence[token] = 1
                      else:
                      # 否則+1,表示這個(gè)token重復(fù)次數(shù)+1
                          occurrence[token] += 1

                      # 如果 msg[:4].lower() == '+qa ',調(diào)用 run_rnn 函數(shù),
                      # 傳遞 [token] 作為參數(shù),并將返回值賦值給變量 out。
                      if msg[:4].lower() == '+qa ':# or msg[:4].lower() == '+qq ':
                          out = run_rnn([token], newline_adj=-2)
                      else:
                          # 調(diào)用 run_rnn 函數(shù),傳遞 [token] 作為參數(shù),并將返回值賦值給變量 out。
                          out = run_rnn([token])
                      
                      # 使用 pipeline.decode 函數(shù)將 model_tokens[out_last:] 解碼為字符串,并將結(jié)果賦值給 xxx。
                      xxx = pipeline.decode(model_tokens[out_last:])
                      # 如果字符串 '\ufffd' 不在 xxx 中
                      if '\ufffd' not in xxx: # avoid utf-8 display issues
                          # 打印 xxx,并刷新輸出緩沖區(qū)。
                          print(xxx, end='', flush=True)
                          # 將 out_last 設(shè)置為 begin + i + 1。
                          out_last = begin + i + 1
                          # 如果 i 大于等于 FREE_GEN_LEN,跳出外層循環(huán)。
                          if i >= FREE_GEN_LEN:
                              break
                  print('\n')
                  # send_msg = pipeline.decode(model_tokens[begin:]).strip()
                  # print(f'### send ###\n[{send_msg}]')
                  # reply_msg(send_msg)
                  # 調(diào)用 save_all_stat 函數(shù),將參數(shù) srv、'gen_1' 和 out 傳遞給它。
                  save_all_stat(srv, 'gen_1', out)

              else:
                  # 如果 msg.lower() == '+'
                  if msg.lower() == '+':
                      try:
                          # 嘗試加載狀態(tài)信息 load_all_stat(srv, 'chat_pre')。
                          out = load_all_stat(srv, 'chat_pre')
                      except:
                          return
                  else:
                      # 加載狀態(tài)信息 load_all_stat(srv, 'chat'),并將結(jié)果賦值給變量 out。
                      out = load_all_stat(srv, 'chat')
                      # 對(duì)消息 msg 進(jìn)行處理,去除首尾空格,替換換行符,
                      msg = msg.strip().replace('\r\n','\n').replace('\n\n','\n')
                      # 并構(gòu)造新的消息字符串 new,其中包括用戶和機(jī)器人的標(biāo)識(shí)符。
                      new = f"{user}{interface} {msg}\n\n{bot}{interface}"
                      # print(f'### add ###\n[{new}]')
                      # 調(diào)用 run_rnn 函數(shù),傳遞 pipeline.encode(new) 作為參數(shù),并將返回值賦值給變量 out。
                      out = run_rnn(pipeline.encode(new), newline_adj=-999999999)
                      # 將生成的狀態(tài)信息 out 保存為 'chat_pre' 的狀態(tài)信息。
                      save_all_stat(srv, 'chat_pre', out)

                  # 這里開(kāi)始的內(nèi)容和上一個(gè)elif里面的基本一致,就不重復(fù)解析了
                  begin = len(model_tokens)
                  out_last = begin
                  print(f'{bot}{interface}', end='', flush=True)
                  occurrence = {}
                  for i in range(999):
                      if i <= 0:
                          newline_adj = -999999999
                      elif i <= CHAT_LEN_SHORT:
                          newline_adj = (i - CHAT_LEN_SHORT) / 10
                      elif i <= CHAT_LEN_LONG:
                          newline_adj = 0
                      else:
                          newline_adj = min(3, (i - CHAT_LEN_LONG) * 0.25# MUST END THE GENERATION

                      for n in occurrence:
                          out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
                      token = pipeline.sample_logits(
                          out,
                          temperature=x_temp,
                          top_p=x_top_p,
                      )
                      # if token == END_OF_TEXT:
                      #     break
                      for xxx in occurrence:
                          occurrence[xxx] *= GEN_penalty_decay            
                      if token not in occurrence:
                          occurrence[token] = 1
                      else:
                          occurrence[token] += 1
                      
                      out = run_rnn([token], newline_adj=newline_adj)
                      out[END_OF_TEXT] = -999999999  # disable <|endoftext|>

                      xxx = pipeline.decode(model_tokens[out_last:])
                      if '\ufffd' not in xxx: # avoid utf-8 display issues
                          print(xxx, end='', flush=True)
                          out_last = begin + i + 1
                      
                      send_msg = pipeline.decode(model_tokens[begin:])
                      if '\n\n' in send_msg:
                          send_msg = send_msg.strip()
                          break
                      
                      # send_msg = pipeline.decode(model_tokens[begin:]).strip()
                      # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
                      #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
                      #     break
                      # if send_msg.endswith(f'{bot}{interface}'):
                      #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
                      #     break

                  # print(f'{model_tokens}')
                  # print(f'[{pipeline.decode(model_tokens)}]')

                  # print(f'### send ###\n[{send_msg}]')
                  # reply_msg(send_msg)
                  save_all_stat(srv, 'chat', out)

          ########################################################################################################

          總的來(lái)看,這段代碼是一個(gè)循環(huán),用于ChatRWKV系統(tǒng)生成回復(fù)消息和更新?tīng)顟B(tài)信息。它根據(jù)輸入消息和模型的狀態(tài)信息進(jìn)行一個(gè)RNN模式的推理,生成一個(gè)回復(fù)的token和一個(gè)新的狀態(tài),然后將生成的回復(fù)顯示出來(lái),生成的狀態(tài)則可以在下一次生成中繼續(xù)使用。代碼還包括處理特殊命令以及加載和保存狀態(tài)信息的邏輯。

          0x3. ChatRWKV v2聊天系統(tǒng)指南
          • 直接說(shuō)話:聊天,用 + 換個(gè)回答
          • reset: 通過(guò)發(fā)送+reset消息,您可以重置聊天。
          • 加載新的提示: 使用+prompt {path}可以加載一個(gè)新的提示,其中{path}是提示文件的路徑。
          • 消息類型:
            • +gen {text}: 基于{text}生成新的內(nèi)容。
            • +i {instruction}: 根據(jù)給定的指令{instruction}產(chǎn)生一個(gè)響應(yīng)。
            • +qq {question}: 對(duì)于給定的問(wèn)題{question},生成一個(gè)答案。
            • +qa {text}: 將{text}作為一個(gè)問(wèn)題,并生成一個(gè)答案。
            • +++: 繼續(xù)寫下去。
            • ++: 換個(gè)寫法。

          除了這些指令之外,還可以調(diào)整生成參數(shù):

          • -temp=: 調(diào)整生成的溫度。溫度影響生成文本的隨機(jī)性。較高的溫度將導(dǎo)致更多的隨機(jī)輸出,而較低的溫度將導(dǎo)致更確定的輸出。
          • -top_p=: 調(diào)整Top-P采樣。Top-P采樣是選擇詞匯中概率最高的一部分詞進(jìn)行采樣的方法。
          0x4. 總結(jié)

          這篇文章還有一些ChatRWKV v2系統(tǒng)的模型實(shí)現(xiàn)部分,tokenizer部分都沒(méi)有解析到,但目前篇幅已經(jīng)比較多了,希望留到下次解析。enjoy ChatRWKV v2!

          瀏覽 275
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  三级国产视频 | 国产系列第一页在线观看 | 小早川怜子 无码 在线 | 大香蕉情色 | 三级黄瓜视频 |