当前位置: 首页 > news >正文

小网站源码有没有免费的虚拟主机

小网站源码,有没有免费的虚拟主机,多作者wordpress插件,广州市建筑信息平台目录 一、推测解码speculative decoding 1、自回归解码 2、speculative decoding 3、细节理解 二、核心逻辑代码 1、算法流程代码 2、模型自回归代码 a、带缓存的模型自回归实现代码 b、优化版本带缓存的模型自回归实现代码 c、ChatGLM的past_key_values的回滚 三、…        目录 一、推测解码speculative decoding 1、自回归解码 2、speculative decoding 3、细节理解 二、核心逻辑代码 1、算法流程代码 2、模型自回归代码 a、带缓存的模型自回归实现代码 b、优化版本带缓存的模型自回归实现代码 c、ChatGLM的past_key_values的回滚 三、效果实测 1、效果对比 2、解码日志展示 大模型时代模型的推理效率尤为重要推理速度的快慢和模型生成的质量好坏对用户的体验影响很大。大模型生成速度慢生成效果好小模型推理速度快但是推理质量稍差。当前大模型推理速度满不足不了业务实效性需求小模型不能满足业务质量指标的情况下存不存在一种业务在实际落地的时候最优选择呢google论文Fast Inference from Transformers via Speculative Decoding和deepmind论文Accelerating Large Language Model Decoding with Speculative Sampling给出了相同思路的解决方案也就是这篇博客要谈到的东西Speculative Decoding翻译为推测解码。简单的来说推测解码一种联合小模型和大模型各自的优势的解码算法用大模型来对小模型的生成结果进行评估和修正以实现完美保留大模型的生成质量的同时提升推理速度的目标。 一、推测解码speculative decoding 1、自回归解码 首先看看自回归解码的流程如图所示 1、把输入序列输入模型把输出结果logits处理成概率分布采样得到next token。 2、把新生成的next token和原来的输入序列拼接在一起作为新的输入序列然后重复第一步。直到结束——新生成的token是结尾标识符EOS或者达到最大限定长度。 注意其中1中得到next token可以采用greedy search、beam search、topK和topP sample 以及contrastive search等算法。 2、speculative decoding 推测解码本质也也是自回归解码大体流程上和上述流程相似只不过在推理过程才使用了一大一小的模型对效率和质量做了权衡当然也是以牺牲显存和两次模型训练为代价的。用google论文中的示意图来举例说明 小模型先生成japan’s benchmark bond n大模型对这些文本进行评估发现bond生成不好拒绝该生成丢弃掉bond n 这部分生成修正生成为 nikkei 然后把japan’s benchmark nikkei 拼接起来由小模型继续生成直到结束位置。 具体的流程是什么样的呢看deepmind论文中的算法伪代码 根据上述算法伪代码可以得出Speculative Sampling的具体流程 1、设置小模型每次自回归解码的步数K以及最大解码输出长度T。 2、首先用小模型Mq做自回归生成连续生成采样得到gama个tokens并同时得到每个token的概率分布q_logits。 3、然后用前缀输入prefix和这个gama个tokens拼接在一起输入大模型Mp执行forward得到每个token的概率分布p_logits。 4、使用大小模型的logits做对比如果大模型发现小模型的某个token不好拒绝它使用某种办法重新采样这个token如果小模型生成的token都很好大模型直接生成下一个token 5、然后重复2-4步骤直到结束条件。 3、细节理解 算法中有几个细节需要注意 1、论文中并行的计算得到小模型生成的gama个tokens的大模型对应的logits 其实这里的理解不要太复杂了就是等待小模型生成完了gama个tokens以后直接用大模型进行forward这里就利用了矩阵计算并行的得到了大模型对应的logits。不需要考虑的复杂小模型生成一个token1号大模型就计算一次小模型再生成一个2号大模型再计算该token的logits。这样极其浪费资源并且并没有一次forward那么快。 2、怎么来评价小模型生成的token好不好 利用概率分布来评价这里有个前提大模型的质量比较好生成的结果都是正确的。大模型下该token概率为p小模型下为q如果pq说明大模型都认为这个位置该生成这个token那么就接受这个token反之则以1-p/q的概率拒绝这个token——大模型的概率越小就越要拒绝越大就约不拒绝。 3、拒绝该token后怎么生成新的token 论文中给出了一个方案使用如上公式用大模型的logtis和小模型的logits进行计算得到一个新的概率分布——最后采样的结果就是大模型和小模型概率差异化越大就越能被采样到。 以上就是原始论文中的核心思想和重要的地方但是在代码实现的时候还是有一些细节的使用缓存进行加速推理、logits归一化的加速、采样的方案选择等等对后续算法的效果和效率都有一定的影响。 二、核心逻辑代码 首先声明该本篇博客中实现的算法代码借鉴了方佳瑞博士大模型推理妙招—投机采样Speculative Decoding的实现整体思路局部细节进行了优化并使用ChatGLM基座模型进行了效果实测证实优化后生成质量和推理速度均有提升。 1、算法流程代码 def speculative_decode_generate(self, prompt, gamma, do_sampleFalse, debugFalse, promotionTrue):eos_token self.tokenizer.eos_token_id# 小模型生成类初始化approx_model_kv_cached_genration KVCachedGenration(self.approx_model, eos_token, self.temperature, self.top_k,self.top_p, do_sample)# 大模型生成类初始化target_model_kv_cached_genration KVCachedGenration(self.target_model, eos_token, self.temperature, self.top_k,self.top_p, do_sample)inputs self.tokenizer([prompt], return_tensorspt, paddingTrue)prefix inputs[input_ids].to(self.device)seq_len prefix.shape[1]T seq_len self.max_lenend Tassert prefix.shape[0] 1, input batch size must be 1eos_token_index Nonewith torch.no_grad():while prefix.shape[1] T:prefix_len prefix.shape[1]# 得到小模型的生成结果x, index approx_model_kv_cached_genration.generate_promotion(prefix, gamma, debug)#记录小模型生成的eos_token位置if index is not None:eos_token_index prefix_len index# 得到大模型的生成结果t_x, _ target_model_kv_cached_genration.generate_promotion(x, 1, debug)n prefix_len gamma - 1# 大模型评价小模型效果for i in range(gamma):if self.random_seed:torch.manual_seed(self.random_seed)# r1的时候表示只要大模型的logits比小模型的小就拒绝生成r torch.rand(1, deviceself.device)else:r torch.as_tensor(1.0, dtypetorch.float, deviceself.device)j x[:, prefix_len i]# 拒接 以1-p/q 的概率拒绝 生成if r (target_model_kv_cached_genration.prob_history[:, prefix_len i - 1,j] / approx_model_kv_cached_genration.prob_history[:, prefix_len i - 1, j]):n prefix_len i - 1break# 根据小模型采纳的token位置来回滚past_key_values缓存approx_model_kv_cached_genration.kv_rollback(n 1)prefix x[:, :n 1]if eos_token_index is not None:# 如果小模型生成的eos不被接受则继续生成if eos_token_index n:eos_token_index Noneelse:# 如果小模型生成的eos被接受则停止break# 如果拒绝一些小模型生成的token从大模型中采样if n prefix_len gamma - 1:if do_sample:t sample(max_fn(target_model_kv_cached_genration.prob_history[:, n,:] - approx_model_kv_cached_genration.prob_history[:, n, :]))else:# t torch.argmax(max_fn(target_model_kv_cached_genration.prob_history[:, n,# :] - approx_model_kv_cached_genration.prob_history[:, n, :]),# dim-1).unsqueeze(0)# 直接取大模型logits最大的那个tokent torch.argmax(target_model_kv_cached_genration.prob_history[:, n, :], dim-1).unsqueeze(0)# 根据生成的token位置大模型回滚past_key_values缓存target_model_kv_cached_genration.kv_rollback(n 1)else:if do_sample:t sample(target_model_kv_cached_genration.prob_history[:, -1, :])else:t torch.argmax(target_model_kv_cached_genration.prob_history[:, -1, :], dim-1).unsqueeze(0)# 根据生成的token位置大模型回滚past_key_values缓存target_model_kv_cached_genration.kv_rollback(n 2)prefix torch.cat((prefix, t), dim1)# 如果生成eos token 则停止生成if t eos_token:eos_token_index prefix.shape[1] - 1breakif eos_token_index:end eos_token_index 1input_ids inputs[input_ids].tolist()[0]assert len(input_ids) end, eos_token选择位置错误output_token_ids prefix.tolist()[0][len(input_ids):end]#得到最终输出response self.tokenizer.decode(output_token_ids)代码中实现了上述推测解码的整个流程并给出了退出解码过程的条件。也实现了不同采样token的方式以及不同拒绝小模型token的策略对实际结果也是很有影响的。 2、模型自回归代码 这里其实也比较细节的实现的好不好会影响模型最终生成的效率的加不加缓存以及向量归一化都会极大影响模型推理速度。 a、带缓存的模型自回归实现代码 torch.no_grad() def generate(self, input_ids, gamma):eos_index Nonex input_idsfor index, _ in enumerate(range(gamma)):# past_key_values存在非首次推理if self.past_key_values:kv_cached_len self.past_key_values[0][0].shape[0]# 获取位置idsposition_ids self.get_position_ids(x, x.device)position_ids position_ids[..., kv_cached_len:]# ChatGLM一定要输入位置编码一般推理的时候模型内部会自己配置位置编码的outputs self.model(x[:, kv_cached_len:], position_idsposition_ids,past_key_valuesself.past_key_values, use_cacheTrue)logits outputs.logits# logits归一化for i in range(logits.shape[1]):logits[:, i, :] norm_logits(logits[:, i, :], self.temperature, self.top_k, self.top_p)probs logits[:, -1, :]self.past_key_values outputs.past_key_valuesself.prob_history torch.cat([self.prob_history, logits], dim1)else:# past_key_values不存在首次推理outputs self.model(x)self.prob_history outputs.logits# logits归一化for i in range(self.prob_history.shape[1]):self.prob_history[:, i, :] norm_logits(self.prob_history[:, i, :], self.temperature, self.top_k,self.top_p)self.past_key_values outputs.past_key_valuesprobs self.prob_history[:, -1, :]# 采样得到next_tokif self.do_sample:next_tok sample(probs)else:next_tok torch.argmax(probs, dim-1).unsqueeze(0)x torch.cat((x, next_tok), dim1)if next_tok self.eos_token and eos_index is None:eos_index indexreturn x, eos_index# copy from https://github.com/LeeSinLiang/microGPT/blob/ed40cf9780dbeb180adfe94c227d4aa97e69250e/gpt.py def top_k_top_p_filter(logits: torch.Tensor, top_k: int 0, top_p: float 0.0):Args:logits (torch.Tensorpe_): 2D tensor with shape (batch, vocab)top_k (int, optional): top_k. Defaults to 0.top_p (float, optional): top_p. Defaults to 0.0.Returns:torch.Tensor: a renormalized logitsif top_k 0:filter torch.topk(logits, min(top_k, logits.size(-1)))[0]logits[logits filter[:, [-1]]] float(-inf)if top_p 0.0:sorted_logits, sorted_indices torch.sort(logits, descendingTrue)cumulative_probs torch.cumsum(F.softmax(sorted_logits, dim-1), dim-1)filter cumulative_probs top_pfilter[..., 1:] filter[..., :-1].clone()filter[..., 0] 0indices_to_remove filter.scatter(1, sorted_indices, filter)logits[indices_to_remove] float(-inf)return logitsdef norm_logits(logits: torch.Tensor, temperature: float, top_k: float, top_p: float) - torch.Tensor:Args:logits (torch.Tensor): shape (1, vocab)temperature (float): temperaturetop_k (float): top_ktop_p (float): top_pReturns:torch.Tensor: next token with shape as (batch, 1)assert logits.dim() 2logits logits / temperaturelogits top_k_top_p_filter(logits, top_ktop_k, top_ptop_p)probs F.softmax(logits, dim1)return probsdef sample(probs: torch.Tensor, num_samples: int 1):idx_next torch.multinomial(probs, num_samplesnum_samples)if (idx_next.item() 0):raise RuntimeErrorreturn idx_next 正常的自回归流程每次生成的时候都会传入past_key_values(这个时候模型输入就是x[:, kv_cached_len:])进行加速同时输入position_ids保证结果正确。得到logits后进行归一化(为了后面能够对比大小模型的logits大模型能够评价小模型必须scale到同一个值域里面)、topK、topP以及softmax处理然后采样得到下一个token。以上是没有优化过的版本在归一化、topK、topP以及softmax处理过程中大小模型首次推理以及大模型非首次推理都是多个token代码里采用for循环的方式实现耗时太多可以矩阵计算并行化。 b、优化版本带缓存的模型自回归实现代码 并不是优化自回归这个算法本身是优化自回归过程中大小模型首次推理以及大模型非首次推理logits归一化的实现。这里采用并行计算也就是矩阵计算。 def norm_logits_whole(logits: torch.Tensor, temperature: float, top_k: float, top_p: float) - torch.Tensor:# 优化一版提速 矩阵计算代替原来的for循环assert logits.dim() 3Args:logits (torch.Tensor): shape (1, vocab)temperature (float): temperaturetop_k (float): top_ktop_p (float): top_pReturns:torch.Tensor: next token with shape as (batch, 1)def top(scores: torch.FloatTensor, top_k: int 0, top_p: float 0.0):sorted_logits, sorted_indices torch.sort(scores, descendingFalse)cumulative_probs sorted_logits.softmax(dim-1)cumulative_probs cumulative_probs.cumsum(dim-1)sorted_indices_to_remove cumulative_probs (1 - top_p)sorted_indices_to_remove[..., -1:] 0# 主要是这句代码优化indices_to_remove sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)scores scores.masked_fill(indices_to_remove, -float(Inf))return scoresif top_k 0:filter torch.topk(logits, min(2, logits.size(-1)))[0]logits[logits filter[..., [-1]]] float(-inf)logits logits / temperaturelogits top(logits, top_k, top_p)probs F.softmax(logits, dim2)return probstorch.no_grad() def generate_promotion(self, input_ids, gamma, debugFalse):eos_index Nonex input_idsfor index, _ in enumerate(range(gamma)):if self.past_key_values:kv_cached_len self.past_key_values[0][0].shape[0]position_ids self.get_position_ids(x, x.device)position_ids position_ids[..., kv_cached_len:]# ChatGLM一定要输入位置编码一般推理的时候模型内部会自己配置位置编码的outputs self.model(x[:, kv_cached_len:], position_idsposition_ids,past_key_valuesself.past_key_values, use_cacheTrue)logits outputs.logits# 优化提速版本归一化logits norm_logits_whole(logits, self.temperature, self.top_k, self.top_p)probs logits[:, -1, :]self.past_key_values outputs.past_key_valuesself.prob_history torch.cat([self.prob_history, logits], dim1)else:outputs self.model(x)self.prob_history outputs.logits# 优化提速版本归一化self.prob_history norm_logits_whole(self.prob_history, self.temperature, self.top_k, self.top_p)self.past_key_values outputs.past_key_valuesprobs self.prob_history[:, -1, :]if self.do_sample:next_tok sample(probs)else:next_tok torch.argmax(probs, dim-1).unsqueeze(0)x torch.cat((x, next_tok), dim1)if next_tok self.eos_token and eos_index is None:eos_index indexreturn x, eos_indexdef sample(probs: torch.Tensor, num_samples: int 1):idx_next torch.multinomial(probs, num_samplesnum_samples)if (idx_next.item() 0):raise RuntimeErrorreturn idx_next c、ChatGLM的past_key_values的回滚 torch.no_grad() def kv_rollback(self, end_index):past_key_values_keeps []for index, kv in enumerate(self.past_key_values):k, v kvk k[:end_index, :, :, :]v v[:end_index, :, :, :]kv (k, v)past_key_values_keeps.append(kv)self.past_key_values past_key_values_keepsself.prob_history self.prob_history[:, :end_index, :] 注意past_key_values不同的模型的维度不一样llama k, v  (batch, num_head, seq_len, hidden_dim)ChatGLM k, v (seq, batch, head, hidden_dim)而Bloom  k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)在实现的时候需要注意。 比较重要的代码笔记都介绍完毕了没有写文件的依赖和一些不重要的逻辑放一下我这边的一个python文件speculate_decoding.py的全部代码——实现了不同的采样方式、gama动态输入、是否debug等 import torch from torch.nn import functional as F import timetorch.set_printoptions(precision10)# copy from https://github.com/LeeSinLiang/microGPT/blob/ed40cf9780dbeb180adfe94c227d4aa97e69250e/gpt.py def top_k_top_p_filter(logits: torch.Tensor, top_k: int 0, top_p: float 0.0):Args:logits (torch.Tensorpe_): 2D tensor with shape (batch, vocab)top_k (int, optional): top_k. Defaults to 0.top_p (float, optional): top_p. Defaults to 0.0.Returns:torch.Tensor: a renormalized logitsif top_k 0:filter torch.topk(logits, min(top_k, logits.size(-1)))[0]logits[logits filter[:, [-1]]] float(-inf)if top_p 0.0:sorted_logits, sorted_indices torch.sort(logits, descendingTrue)cumulative_probs torch.cumsum(F.softmax(sorted_logits, dim-1), dim-1)filter cumulative_probs top_pfilter[..., 1:] filter[..., :-1].clone()filter[..., 0] 0indices_to_remove filter.scatter(1, sorted_indices, filter)logits[indices_to_remove] float(-inf)return logitsdef norm_logits(logits: torch.Tensor, temperature: float, top_k: float, top_p: float) - torch.Tensor:Args:logits (torch.Tensor): shape (1, vocab)temperature (float): temperaturetop_k (float): top_ktop_p (float): top_pReturns:torch.Tensor: next token with shape as (batch, 1)assert logits.dim() 2logits logits / temperaturelogits top_k_top_p_filter(logits, top_ktop_k, top_ptop_p)probs F.softmax(logits, dim1)return probsdef sample(probs: torch.Tensor, num_samples: int 1):idx_next torch.multinomial(probs, num_samplesnum_samples)if (idx_next.item() 0):raise RuntimeErrorreturn idx_nextdef norm_logits_whole(logits: torch.Tensor, temperature: float, top_k: float, top_p: float) - torch.Tensor:# 优化一版提速 矩阵计算代替原来的for循环assert logits.dim() 3Args:logits (torch.Tensor): shape (1, vocab)temperature (float): temperaturetop_k (float): top_ktop_p (float): top_pReturns:torch.Tensor: next token with shape as (batch, 1)def top(scores: torch.FloatTensor, top_k: int 0, top_p: float 0.0):sorted_logits, sorted_indices torch.sort(scores, descendingFalse)cumulative_probs sorted_logits.softmax(dim-1)cumulative_probs cumulative_probs.cumsum(dim-1)# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)sorted_indices_to_remove cumulative_probs (1 - top_p)# Keep at least min_tokens_to_keepsorted_indices_to_remove[..., -1:] 0# indices_to_remove sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)# 这个函数比较难理解、三维矩阵的scatter——就是把sorted_indices_to_remove 用 sorted_indices指定的index来还原# print(sorted_indices,sorted_indices)# print(sorted_indices_to_remove, sorted_indices_to_remove)indices_to_remove sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)# print(indices_to_remove, indices_to_remove)scores scores.masked_fill(indices_to_remove, -float(Inf))return scoresif top_k 0:filter torch.topk(logits, min(2, logits.size(-1)))[0]logits[logits filter[..., [-1]]] float(-inf)logits logits / temperaturelogits top(logits, top_k, top_p)probs F.softmax(logits, dim2)return probsdef max_fn(x):norm(max (x, 0))x_max torch.where(x 0, x, torch.zeros_like(x))x_max_sum torch.sum(x_max, dim1, keepdimTrue)return x_max / x_max_sumclass KVCachedGenration(object):def __init__(self, model, eos_token, temperature: float 1.0, top_k: int 0, top_p: float 0, do_sampleFalse):self.model modelself.eos_token eos_tokenself.past_key_values Noneself.prob_history Noneself.do_sample do_sampleself.temperature temperatureself.top_k top_kself.top_p top_pdef get_position_ids(self, input_ids, device):batch_size, seq_length input_ids.shapeposition_ids torch.arange(seq_length, dtypetorch.long, devicedevice).unsqueeze(0).repeat(batch_size, 1)return position_idstorch.no_grad()def generate(self, input_ids, gamma, debugFalse):eos_index Nonex input_idsfor index, _ in enumerate(range(gamma)):if self.past_key_values:kv_cached_len self.past_key_values[0][0].shape[0]position_ids self.get_position_ids(x, x.device)position_ids position_ids[..., kv_cached_len:]# ChatGLM一定要输入位置编码一般推理的时候模型内部会自己配置位置编码的t1 time.time()outputs self.model(x[:, kv_cached_len:], position_idsposition_ids,past_key_valuesself.past_key_values, use_cacheTrue)t2 time.time()if debug:print(fpost inference time {round((t2 - t1) * 1000, 4)} ms, data shape {x[:, kv_cached_len:].shape})logits outputs.logitst1 time.time()for i in range(logits.shape[1]):logits[:, i, :] norm_logits(logits[:, i, :], self.temperature, self.top_k, self.top_p)t2 time.time()if debug:print(fnorm_logits time {round((t2 - t1) * 1000, 4)} ms, data shape {x[:, kv_cached_len:].shape})probs logits[:, -1, :]self.past_key_values outputs.past_key_valuesself.prob_history torch.cat([self.prob_history, logits], dim1)else:t1 time.time()outputs self.model(x)t2 time.time()if debug:print(ffirst inference time {round((t2 - t1) * 1000, 4)} ms, data shape {x.shape})self.prob_history outputs.logitst1 time.time()for i in range(self.prob_history.shape[1]):self.prob_history[:, i, :] norm_logits(self.prob_history[:, i, :], self.temperature, self.top_k,self.top_p)t2 time.time()if debug:print(fnorm_logits time {round((t2 - t1) * 1000, 4)} ms, data shape {self.prob_history.shape})self.past_key_values outputs.past_key_valuesprobs self.prob_history[:, -1, :]if self.do_sample:next_tok sample(probs)else:next_tok torch.argmax(probs, dim-1).unsqueeze(0)x torch.cat((x, next_tok), dim1)if next_tok self.eos_token and eos_index is None:eos_index indexreturn x, eos_indextorch.no_grad()def generate_promotion(self, input_ids, gamma, debugFalse):eos_index Nonex input_idsfor index, _ in enumerate(range(gamma)):if self.past_key_values:kv_cached_len self.past_key_values[0][0].shape[0]position_ids self.get_position_ids(x, x.device)position_ids position_ids[..., kv_cached_len:]# ChatGLM一定要输入位置编码一般推理的时候模型内部会自己配置位置编码的t1 time.time()outputs self.model(x[:, kv_cached_len:], position_idsposition_ids,past_key_valuesself.past_key_values, use_cacheTrue)t2 time.time()if debug:print(fpost inference time {round((t2 - t1) * 1000, 4)} ms, data shape {x[:, kv_cached_len:].shape})logits outputs.logits# 优化一版提速t1 time.time()logits norm_logits_whole(logits, self.temperature, self.top_k, self.top_p)t2 time.time()if debug:print(fnorm_logits_whole time {round((t2 - t1) * 1000, 4)} ms, data shape {x[:, kv_cached_len:].shape})probs logits[:, -1, :]self.past_key_values outputs.past_key_valuesself.prob_history torch.cat([self.prob_history, logits], dim1)else:t1 time.time()outputs self.model(x)t2 time.time()if debug:print(ffirst inference time {round((t2 - t1) * 1000, 4)} ms, data shape {x.shape})self.prob_history outputs.logits# 优化一版提速t1 time.time()self.prob_history norm_logits_whole(self.prob_history, self.temperature, self.top_k, self.top_p)t2 time.time()if debug:print(fnorm_logits_whole time {round((t2 - t1) * 1000, 4)} ms, data shape {self.prob_history.shape})self.past_key_values outputs.past_key_valuesprobs self.prob_history[:, -1, :]if self.do_sample:next_tok sample(probs)else:next_tok torch.argmax(probs, dim-1).unsqueeze(0)x torch.cat((x, next_tok), dim1)if next_tok self.eos_token and eos_index is None:eos_index indexreturn x, eos_indextorch.no_grad()def kv_rollback(self, end_index):# kv cache 回滚到有效生成位置# ChatGLM k, v (seq, batch, head, hidden_dim)# for i in range(len(self.past_key_values)):# self.past_key_values[i][0] self.past_key_values[i][0][:, :, :end_index, :]# self.past_key_values[i][1] self.past_key_values[i][1][:, :, :end_index, :]# self.prob_history self.prob_history[:, :end_index, :]# kv cache 回滚到有效生成位置# ChatGLM k, v (seq, batch, head, hidden_dim)past_key_values_keeps []for index, kv in enumerate(self.past_key_values):k, v kvk k[:end_index, :, :, :]v v[:end_index, :, :, :]kv (k, v)past_key_values_keeps.append(kv)self.past_key_values past_key_values_keepsself.prob_history self.prob_history[:, :end_index, :]class SpeculateDecoding(object):def __init__(self, approx_model, target_model, tokenizer, temperature1.0, top_k0, top_p0.85, max_len100,random_seedNone):self.approx_model approx_modelself.target_model target_modelself.tokenizer tokenizerself.device approx_model.deviceself.random_seed random_seedself.temperature temperatureself.top_k top_kself.top_p top_pself.max_len max_lendef speculative_decode_generate(self, prompt, gamma, do_sampleFalse, debugFalse, promotionTrue):eos_token self.tokenizer.eos_token_idt1 time.time()approx_model_kv_cached_genration KVCachedGenration(self.approx_model, eos_token, self.temperature, self.top_k,self.top_p, do_sample)target_model_kv_cached_genration KVCachedGenration(self.target_model, eos_token, self.temperature, self.top_k,self.top_p, do_sample)t2 time.time()if debug:print(fcached_genration init {round((t2 - t1) * 1000, 4)} ms)t1 time.time()inputs self.tokenizer([prompt], return_tensorspt, paddingTrue)t2 time.time()if debug:print(ftoken2ids {round((t2 - t1) * 1000, 4)} ms)t1 time.time()prefix inputs[input_ids].to(self.device)t2 time.time()if debug:print(ftensor to gpu {round((t2 - t1) * 1000, 4)} ms)seq_len prefix.shape[1]reject_count 0T seq_len self.max_lenend Tassert prefix.shape[0] 1, input batch size must be 1eos_token_index Nonestart_t time.time()with torch.no_grad():while prefix.shape[1] T:prefix_len prefix.shape[1]if promotion:x, index approx_model_kv_cached_genration.generate_promotion(prefix, gamma, debug)else:x, index approx_model_kv_cached_genration.generate(prefix, gamma, debug)if debug:print(小模型生成的token, [self.tokenizer.decode(x[:, prefix_len:].tolist()[0])], index: ,x[:, prefix_len:].tolist()[0], eos index:, index)if index is not None:eos_token_index prefix_len indexif promotion:t_x, _ target_model_kv_cached_genration.generate_promotion(x, 1, debug)else:t_x, _ target_model_kv_cached_genration.generate(x, 1, debug)n prefix_len gamma - 1for i in range(gamma):if self.random_seed:torch.manual_seed(self.random_seed)r torch.rand(1, deviceself.device)else:r torch.as_tensor(1.0, dtypetorch.float, deviceself.device)j x[:, prefix_len i]# 拒接 以1-p/q 的概率拒绝 生成if r (target_model_kv_cached_genration.prob_history[:, prefix_len i - 1,j] / approx_model_kv_cached_genration.prob_history[:, prefix_len i - 1, j]):n prefix_len i - 1reject_count 1if debug:print(被拒绝的token, [self.tokenizer.decode(j)], index: , j, 被拒绝的位置, i)print(target_model,target_model_kv_cached_genration.prob_history[:, prefix_len i - 1, j])print(approx_model,approx_model_kv_cached_genration.prob_history[:, prefix_len i - 1, j])breakapprox_model_kv_cached_genration.kv_rollback(n 1)prefix x[:, :n 1]if debug:print(eos_token_index:, eos_token_index)print(n:, n)# # 如果小模型生成的eos不被接受则继续生成# if eos_token_index and eos_token_index n:# eos_token_index None## # 如果小模型生成的eos被接受则停止# if eos_token_index and eos_token_index n:# breakif eos_token_index is not None:# 如果小模型生成的eos不被接受则继续生成if eos_token_index n:eos_token_index Noneelse:# 如果小模型生成的eos被接受则停止break# 如果拒绝一些小模型生成的token从大模型中采样if n prefix_len gamma - 1:if do_sample:t sample(max_fn(target_model_kv_cached_genration.prob_history[:, n,:] - approx_model_kv_cached_genration.prob_history[:, n, :]))else:# t torch.argmax(max_fn(target_model_kv_cached_genration.prob_history[:, n,# :] - approx_model_kv_cached_genration.prob_history[:, n, :]),# dim-1).unsqueeze(0)# 直接取大模型logits最大的那个tokent torch.argmax(target_model_kv_cached_genration.prob_history[:, n, :], dim-1).unsqueeze(0)if debug:print(拒绝后大模型重新生成的token, [self.tokenizer.decode(t.tolist()[0])], index: , t.tolist()[0])target_model_kv_cached_genration.kv_rollback(n 1)else:if do_sample:t sample(target_model_kv_cached_genration.prob_history[:, -1, :])else:t torch.argmax(target_model_kv_cached_genration.prob_history[:, -1, :], dim-1).unsqueeze(0)target_model_kv_cached_genration.kv_rollback(n 2)if debug:print(接受后大模型生成的next token, [self.tokenizer.decode(t.tolist()[0])], index: , t.tolist()[0])prefix torch.cat((prefix, t), dim1)if t eos_token:eos_token_index prefix.shape[1] - 1breakend_time time.time()if debug:print(fwhile loop time {round((end_time - start_t) * 1000, 4)} ms)if eos_token_index:end eos_token_index 1t1 time.time()input_ids inputs[input_ids].tolist()[0]assert len(input_ids) end, eos_token选择位置错误output_token_ids prefix.tolist()[0][len(input_ids):end]response self.tokenizer.decode(output_token_ids)response self.target_model.process_response(response)t2 time.time()if debug:print(fpost process {round((t2 - t1) * 1000, 4)} ms)return response, reject_count 三、效果实测 这一节我们开始效果实测看看具体到实际业务落地效率和准确率指标gama参数(小模型一次生成多少个token)和logits归一化优化的影响等。 测试环境Linux系统、显卡4090显卡、基于ChatGLM6B进行了业务微调得到大模型、花大力气蒸馏了6B得到1.5B小模型。 业务测试数据1088条样例如下 从坐席和客户的对话中提取出空调品牌、空调样式、是否5匹、故障类型、姓名、服务时间、联系方式和地址等等几个重要的信息。注意由于我们业务场景需要准确性不需要模型发散输出多样性我们解码的时候采用greedy search 解码策略也就是生成token的时候采用torch.argmax(logits)来获取下一个token。 1、效果对比 tiny模型和big模型耗时以及准确率如下 小模型 general.chat.tiny: 100%|████████████████████| 1088/1088 [09:0300:00, 2.00it/s] tiny cost total time is: 543.863 S, each time is: 499.87412222168024 ms total 8704 correct 8539 items all acc 0.9810431985294118 {姓名: 0.9898897058823529, 服务时间: 0.9476102941176471, 联系方式: 0.9908088235294118, 地址: 0.9448529411764706, 空调品牌: 0.9963235294117647, 空调样式: 0.9926470588235294, 是否5匹: 0.9972426470588235, 故障类型: 0.9889705882352942} glm tiny acc: 0.859375 大模型 general.chat.common: 100%|█████████| 1088/1088 [31:1200:00, 0.58it/s] common cost total time is: 1871.7942 S, each time is: 1720.3990845264123 ms total 8704 correct 8564 items all acc 0.9839154411764706 {姓名: 0.9880514705882353, 服务时间: 0.9595588235294118, 联系方式: 0.9871323529411765, 地址: 0.9604779411764706, 空调品牌: 0.9926470588235294, 空调样式: 0.9954044117647058, 是否5匹: 0.9954044117647058, 故障类型: 0.9926470588235294} glm common acc: 0.8759191176470589 可以看到我们蒸馏后的小模型效果也很好和大模型差距不大只有1.65个百分点的差距耗时小模型每次推理500ms左右大模型每次1720ms小模型的速度大概是大模型的3.5倍。 看speculate sampling算法 logits归一化未优化  gama7 r troch.rand(1) 大模型argmax生成下一个token 以1-p(x)/q(x)的概率拒绝小模型的token gama7 general.chat.speculatedecode: 100%|█████████| 1088/1088 [15:1400:00, 1.19it/s] speculate cost total time is: 914.1663 S, each time is: 840.2263719369383 ms total 8704 correct 8576 items all acc 0.9852941176470589 {姓名: 0.9880514705882353, 服务时间: 0.9604779411764706, 联系方式: 0.9898897058823529, 地址: 0.9641544117647058, 空调品牌: 0.9944852941176471, 空调样式: 0.9944852941176471, 是否5匹: 0.9972426470588235, 故障类型: 0.9935661764705882} glm speculate acc: 0.8869485294117647 logits归一化优化   gama7 r troch.rand(1) 大模型argmax生成下一个token 以1-p(x)/q(x)的概率拒绝小模型的token gama7 优化版本 general.chat.speculatedecode: 100%|█████████| 1088/1088 [12:5600:00, 1.40it/s] speculate promotion cost total time is: 776.8863 S, each time is: 714.0498757362366 ms total 8704 correct 8577 items all acc 0.9854090073529411 {姓名: 0.9880514705882353, 服务时间: 0.9613970588235294, 联系方式: 0.9898897058823529, 地址: 0.9641544117647058, 空调品牌: 0.9944852941176471, 空调样式: 0.9944852941176471, 是否5匹: 0.9972426470588235, 故障类型: 0.9935661764705882} glm speculate promotion acc: 0.8878676470588235 logits归一化优化   gama14 r troch.rand(1) 大模型argmax生成下一个token 以1-p(x)/q(x)的概率拒绝小模型的token gama14 优化版本 general.chat.speculatedecode: 100%|█████████| 1088/1088 [11:4100:00, 1.55it/s] speculate promotion cost total time is: 700.5773 S, each time is: 643.9129872357144 ms total 8704 correct 8574 items all acc 0.9850643382352942 {姓名: 0.9880514705882353, 服务时间: 0.9604779411764706, 联系方式: 0.9889705882352942, 地址: 0.9632352941176471, 空调品牌: 0.9944852941176471, 空调样式: 0.9954044117647058, 是否5匹: 0.9972426470588235, 故障类型: 0.9926470588235294} glm speculate promotion gama 14 acc: 0.8851102941176471 speculate sampling算法确实有效相对大模型而言确实能提升推理速度并完美保持大模型的生成质量。基线效果是提速1.05倍、优化logits归一化后提速1.41倍、gama14时提速1.67倍同时算法在测试集上的准确率均超过单一大模型准确率有偶然性这个和测试数据以及大模型和小模型的效果强相关当然不仅仅是准确率推理速度也是和上述因素强相关的。 另一方面gama增加的同时只要小模型效果还行那被拒绝的次数不太多的情况下一次推理过程大模型的forward次数就减少了那么提速是理所当然的了。 2、解码日志展示 最后我们看一下解码过程中debug的日志同一个输入 gama7  logits归一化未优化时 while loop time 795.3556 ms 可以看到大模型forward耗时是小模型forward耗时的3-5倍之间norm_logits的时间在首次forward后seq_len为352耗时80ms左右小模型和大模型都要执行一次后面大模型forward每次过8个tokennorm_logits为1.75ms左右。 gama7  logits归一化优化后 while loop time 691.3598 ms 每次inference大小模型forward基本没有变化变化的是norm_logits为优化后的norm_logits_whole首次forward后seq_len为352耗时0.9ms左右这块儿优化耗时巨大后面大模型forward每次过8个tokennorm_logits_whole为0.25ms左右也有不小的优化。总体上来看也有100ms左右的优化。 gama14  logits归一化优化后 while loop time 542.4328 ms 当gama14和gama7对比大模型forward一次15个token为25ms左右8个token也为25ms左右耗时差不多这个和GPU的矩阵并行计算有关但是整个解码算法过程中大模型forward次数为4gama7时大模型forward次数为8而小模型耗时确实基本不变因此整体上优化耗时为4*25ms100ms左右再加上其他的一些随机耗时差不多能到100-150ms左右。 以上是一个具体测试用例的分析Speculative Decoding算法的收益和大小两个模型的效果差异有明显关系相差不是太大才会有加速的效果同时小模型的效果又不能比大模型差太小不然直接使用小模型就好了。相对而言应用场景比较受限资源消耗也是需要两次训练资源推理的时候需要大小模型占用的显存。就本次效果测试来说ChatGLM6B大模型推理占用显存12.8G小模型4.5GSpeculative Decoding大小模型联合则占用17.5G左右。当然这个Speculative Decoding算法很优秀实现比较简单确实比较取巧的实现了大模型推理的加速效果同时生成质量完美保留。 最后的最后对推理速度和效果都卡的比较死的业务场景对这个算法有一定的需求。 参考文章 大模型推理妙招—投机采样Speculative Decoding Fast Inference from Transformers via Speculative Decoding Accelerating Large Language Model Decoding with Speculative Sampling LLMSpeculativeSampling代码
http://www.yutouwan.com/news/131634/

相关文章:

  • 广州网站建设电话咨询iis7 wordpress 伪静态
  • 深圳骏域网站建设专家88宁德蕉城城乡建设网站
  • 有友情链接的网站西安网站开发工资
  • 网站建设到运营赚钱投资公司的钱从哪里来
  • 传奇新开网站服有啥可以自己做网站的软件
  • 宁波市江东区地块建设网站怎样进入国外网站
  • wordpress 多域名多站点建立网站需要多少钱多少钱28湖南岚鸿
  • 流行网站设计金融服务网站建设
  • 哪里有工程做的网站wordpress 改登录界面
  • 昆明网站搜索引擎优化vestacp wordpress
  • 网站模板 哪个好pinterest图片wordpress
  • 做爰全的网站自己可以做开奖网站吗
  • 网站开发需要那些技术人员网站关键词用什么符号
  • 企业加盟网站建设开发小程序软件
  • go网站开发网站备案信息如何注销吗
  • 西安网站设计培训试听深圳网站建设最专业的
  • 网站系统繁忙是什么原因中仑建设网站
  • 家庭宽带做网站做网站需要准备的东西
  • 摄影网站网址大全外呼电销系统
  • 男和男做那个视频网站网站建设类文章
  • 做网站优化的工资有多高四川禾力建设工程质量检测有限公司网站
  • 专业俄文网站建设网站建设创建
  • 临汾工程建设招标投标网站发稿计划
  • 北京金港建设股份有限公司网站自己做手机网站
  • 网站编程课程设计心得体会怎么介绍做网站技术
  • 网站后台管理系统管理员登录深圳公司网站设计
  • 番禺网站建设知乎合肥百度推广排名优化
  • 网站设计经典案例欣赏免费注册入口
  • 知乎 拒绝 朋友 做网站论坛网站开发平台
  • 高境网站建设网站编辑专题怎么做