第一句子大全,网罗天下好句子,好文章尽在本站!

当莎士比亚遇见Google Flax:教你用字符级语言模型和归递神经网络写“莎士比亚”式句子

时间:2022-11-14

威廉·莎士比亚第十二夜在几个月前,谷歌的研究人员介绍了机器学习领域的一颗新星Flax

友情提示:本文共有 7955 个字,阅读大概需要 16 分钟。

作者 | Fabian Deuser

译者 | 天道酬勤 责编 | Carol

有些人生来伟大,有些人成就伟大,而另一些人则拥有伟大。—— 威廉·莎士比亚《第十二夜》

在几个月前,谷歌的研究人员介绍了机器学习领域的一颗新星——Flax。从那以后发生了很多事情,预发行版有了巨大的改进。作者自己在Flax上进行的CNNs实验已经取得了成果,与Tensorflow相比,它的灵活性仍然非常好。

今天作者将展示递归神经网络(RNNs)在Flax中的一个应用:字符级语言模型。

在许多学习任务中,我们不必考虑对先前输入的时间依赖性。

但是如果我们没有独立的固定大小的输入和输出向量,该怎么办呢?如果我们有向量序列呢?解决方案是递归神经网络。它们允许我们对下面描述的向量序列进行操作。

递归神经网络

在上图中,你可以看到不同类型的输入输出结构:

一对一是典型CNN或多层感知器,一个输入向量映射到一个输出向量。一对多是用于图像字幕的RNN体系结构。输入是图像,输出是描述图像的单词序列。多对多:第一种体系结构利用输入序列到输出序列进行机器翻译,如(德语译成英语)。第二个是适用于帧级别的视频字幕。RNNs 的主要优点是它们不仅依赖于当前输入,而且还依赖于先前的输入。

RNN是一个具有内部隐藏状态h的单元,该状态根据隐藏的大小用零初始化。在每个时间步长t中,我们将输入x_t插入到RNN单元中,并更新隐藏状态。如今,在下一个时间步t +1中,隐藏状态不再用零初始化,而是使用先前的隐藏状态进行初始化。因此,RNN允许保留有关几个时间步长的信息并生成序列。

字符级语言模型

有了这些新知识,我们现在需要为RNN构建第一个应用程序。字符级语言模型是许多任务的基础,例如图片字幕或文本生成。RNN单元的输入是字符序列形式的大量文本。现在的训练任务是学习在给定先前字符序列的情况下如何预测下一个字符。因此,我们在每个时间步长t生成一个字符,而我们先前的字符是x_t-1,x_t-2…。

举例来说,让我们以FUZZY一词作为训练序列,现在的词汇为{"f","u","z","y"}。由于RNN仅适用于向量,因此我们将所有字符转换为所谓的“单热向量”。单热向量由零组成,其中一个基于词表中的位置为一个,对于“Z”,转换后的向量为[0,0,1,0]。

在下图中,你可以看到给定输入“ FUZZ”的示例,我们希望预测单词“ UZZY”的结尾。神经元的隐藏大小为4,我们希望输出层中的绿色数字较高,而红色为较低。

编程

作者在上一篇有关CNNs的文章中解释了Flax的一些基本概念。作为数据集,我们使用类似这样的对话组成莎士比亚的作品:

EDWARD:Tis even so; yet you are Warwick still.GLOUCESTER:Come, Warwick, take the time; kneel down, kneel down: Nay, when? strike now, or else the iron cools.

我们再次使用Google Colab进行训练,因此我们必须再次安装必要的PIP-Packages:

pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*).([0-9]*),.*/cuda12/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*).([0-9]*).*/cp12/p"`-none-linux_x86_64.whl jaxpip install -q git+https://github.com/google/flax.git@master因为训练任务非常艰巨,你应该使用具有GPU支持的运行。你可以使用以下命令测试是否存在GPU支持:

from jax.lib import xla_bridgeprint(xla_bridge.get_backend.platform)

现在我们准备从头开始创建RNN:

class RNN(flax.nn.Module):"""LSTM"""def apply(self, carry, inputs):carry1, outputs = jax_utils.scan_in_dim(nn.LSTMCell.partial(name="lstm1"), carry[0], inputs, axis=1)carry2, outputs = jax_utils.scan_in_dim(nn.LSTMCell.partial(name="lstm2"), carry[1], outputs, axis=1)carry3, outputs = jax_utils.scan_in_dim(nn.LSTMCell.partial(name="lstm3"), carry[2], outputs, axis=1)x = nn.Dense(outputs, features=params["vocab_length"], name="dense")return [carry1, carry2, carry3], x在这样的实际训练情况下,我们不使用普通的RNN单元,而是使用LSTM单元。这是更进一步的发展,可以更好地解决梯度消失的问题。为了获得更高的精度,我们使用了三个堆叠的LSTM单元。我们将第一个单元的输出传递给下一个单元,并用自己的隐藏状态初始化每个LSTM单元,这一点非常重要。否则,我们将无法追踪时间依赖性。

最后一个LSTM单元的输出提供给我们密集层。密集层的词汇量和我们词汇量相当。在前面的“模糊”示例中,神经元的数量为四个。如果将“ FUZZ”设置为RNN的输入,则神经元最多产生类似于[1.7,0.1,-1.0,3.1]这样的输出,因为此输出表明“ Y”是最可能的字符。

因为我们有两种不同的模式,所以针对不同的情况,我们将RNN包装在另一个模块中。

class charRNN(flax.nn.Module):"""Char Generator"""def apply(self, inputs, carry_pred=None, train=True):batch_size = params["batch_size"]vocab_size = params["vocab_length"]hidden_size = 512if train:carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)carry = [carry1, carry2, carry3]_, x = RNN(carry, inputs)return xelse:carry, x = RNN(carry_pred, inputs)return carry, x这种情况是:

训练模型,我们要学习如何预测。预测模型,实际上在这里我们采样一些文本。在训练模型之前,我们需要使用以下函数创建它:

def create_model(rng):"""Creates a model."""vocab_size = params["vocab_length"]_, initial_params = charRNN.init_by_shape(rng, [((1, params["seq_length"], vocab_size), jnp.float32)])model = nn.Model(charRNN, initial_params)return model我们每个序列长度为50个字符,词汇表包含65个不同的字符。

作为RNN的优化程序,为了避免初始权重过大,我们选择了初始学习率为0.002且权重衰减的Adam优化器。

def create_optimizer(model, learning_rate):"""Creates an Adam optimizer for model."""optimizer_def = optim.Adam(learning_rate=learning_rate, weight_decay=1e-1)optimizer = optimizer_def.create(model)return optimizer

训练模型

在训练模型下,我们将32个序列的批次输入到RNN中。每个序列均取自我们的数据集,并包含两个子序列,一个是子序列的字符从0到49,另一个子序列的字符从1到50。通过这种简单的拆分,我们的网络可以学习到最有可能的下一个字符。在每一批中,我们初始化隐藏状态,并将序列提供给我们的RNN。

@jax.jit

def train_step(optimizer, batch):

"""Train one step."""

def loss_fn(model):

"""Compute cross-entropy loss and predict logits of the current batch"""

logits = model(batch[0])

loss = jnp.mean(cross_entropy_loss(logits, batch[1])) / params["batch_size"]

return loss, logits

def exponential_decay(steps):

"""Decrease the learning rate every 5 epochs"""

x_decay = (steps / params["step_decay"]).astype("int32")

ret = params["learning_rate"]* jax.lax.pow((params["learning_rate_decay"]), x_decay.astype("float32"))

return jnp.asarray(ret, dtype=jnp.float32)

current_step = optimizer.state.step

new_lr = exponential_decay(current_step)

# calculate and apply the gradient

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

(_, logits), grad = grad_fn(optimizer.target)

new_optimizer = optimizer.apply_gradient(grad, learning_rate=new_lr)

metrics = compute_metrics(logits, batch[1])

metrics["learning_rate"] = new_lr

return new_optimizer, metrics

在我们的训练方法中有两个子函数。loss_fn通过将被解释为向量的输出神经元与所需的单热向量进行比较来计算交叉熵损失。因此在“模糊”示例中,我们将有一个输出[1.7,0.1,-1.0,3.1]和一个热向量[0,0,0,1]。现在我们使用以下公式计算损失:

我们不得不从CNN示例中重写一些代码,因为我们现在使用的不是简单类的序列:

@jax.vmap

def cross_entropy_loss(logits, labels):

"""Returns cross-entropy loss."""

return -jnp.mean(jnp.sum(nn.log_softmax(logits) * labels))

训练步骤中的另一种方法是exponential_decay。我们使用的是Adam优化器,初始学习率为0.002。为了避免太强烈的振荡,我们想每五个周期降低学习率。在每五个周期之后,因子0.97乘以我们的初始学习率,x是多长时间我们达到五个时期。

你将再次看到Flax的优势,即以轻松灵活的方式集成自己的学习速率调度程序。

预测模型

现在我们要评估学习模型,因此我们从词汇表中选择一个随机字符作为切入点。像在训练中一样,我们初始化隐藏状态,但是这次只是在采样开始时。现在子函数推断将一个字符作为输入。对于隐藏状态,我们在每个时间步长后输出,并在下一个时间步长中将它们输入到RNN中。因此,我们不会失去时间依赖性。

@jax.jit

def sample(inputs, optimizer):

next_inputs = inputs

output =

batch_size = 1

carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)

carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)

carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)

carry = [carry1, carry2, carry3]

def inference(model, carry):

carry, rnn_output = model(inputs=next_inputs, train=False, carry_pred=carry)

return carry, rnn_output

for i in range(200):

carry, rnn_output = inference(optimizer.target, carry)

output.append(jnp.argmax(rnn_output, axis=-1))

# Select the argmax as the next input.

next_inputs = jnp.expand_dims(common_utils.onehot(jnp.argmax(rnn_output), params["vocab_length"]), axis=0)

return output

这种方法称为“贪婪采样”,因为我们总是取输出向量中概率最大的字符。还有更好的采样方法,比如波束搜索,在此就不做介绍。

训练和样本循环

至少我们可以在训练和样本循环中调用所有编写的函数。

def train_model:

"""Train and inference """

rng = jax.random.PRNGKey(0)

model = create_model(rng)

optimizer = create_optimizer(model, params["learning_rate"])

del model

for epoch in range(100):

for text in tfds.as_numpy(ds):

optimizer, metrics = train_step(optimizer, text)

print("epoch: %d, loss: %.4f, accuracy: %.2f, LR: %.8f" % (epoch+1,metrics["loss"], metrics["accuracy"] * 100, metrics["learning_rate"]))

test = test_ds(params["vocab_length"])

sampled_text = ""

if ((epoch+1)%10 == 0):

for i in test:

sampled_text += vocab[int(jnp.argmax(i.numpy(),-1))]

start = np.expand_dims(i, axis=0)

text = sample(start, optimizer)

for i in text:

sampled_text += vocab[int(i)]

print(sampled_text)

每隔10个周期后,我们会生成一个文本示例,并且在开始时看起来非常重复:

peak the mariners all the merchant of the meaning of the meaning of the meaning of the meaning of the meaning of the meaning…

但是我们变得越来越好,经过100个周期的训练,莎士比亚的作品似乎还活着,并在写新的文字!

This is a shift respected woman to the king"s forth,

To this most dangerous soldier there and fortune.

ANTONIO:

If she would concount a sight on honour

Of the moon, why,...

100个周期训练准确性为86.10%,我们的学习率降至0.00112123。

结论

字符级语言模型的基础是一个能够完成文本的强大工具,可以用作自动补全。可以用作自动补全。也可以利用这个概念来学习一篇文章的观点。但是,生成完整的新文本是一项非常困难的任务。

我们的模型输出的句子看起来像莎士比亚的文本,但它缺乏意义。大家也可以尝试用这种模型并根据有意义的输入创建更有意义的句子。

Flax功能强大且工具众多,但仍处于开发的初期阶段,但它们在开发我喜欢的框架方面处于良好的发展状态。真正巧妙的是,我们只需要稍微更改一下“旧” CNN代码即可在现有基础上使用RNN。

但是Flax仍然缺少它自己的输入管道,因此作者已经用Tensorflow编写了它。如果你想尝试使用作者的代码,你可以在Github Repo中找到用于数据集创建和完整RNN的代码(https://github.com/Skyy93/CharacterLevelModelFlax/)。原文:https://hackernoon.com/shakespeare-meets-googles-flax-8m1r34q9

本文为 AI 科技大本营翻译,转载请经授权。

今日福利

遇见陆奇

同样作为“百万人学 AI”的重要组成部分,2020 AIProCon 开发者万人大会将于 7 月 3 日至 4 日通过线上直播形式,让开发者们一站式学习了解当下 AI 的前沿技术研究、核心技术与应用以及企业案例的实践经验,同时还可以在线参加精彩多样的开发者沙龙与编程项目。参与前瞻系列活动、在线直播互动,不仅可以与上万名开发者们一起交流,还有机会赢取直播专属好礼,与技术大咖连麦。

本文如果对你有帮助,请点赞收藏《当莎士比亚遇见Google Flax:教你用字符级语言模型和归递神经网络写“莎士比亚”式句子》,同时在此感谢原作者。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。
相关阅读
认清这些 跨过小学说明文中的“陷阱”!

认清这些 跨过小学说明文中的“陷阱”!

...说明方法不外乎这么几种:列数字,打比方,作比较和举例子。列数字和打比方是比较容易理解的,特点很明显。但是“作比较”和“举例子”这两种说明方法,有的孩子经常容易搞混。说到两者的含义,作比较就是将两种相似...

2022-12-12 #经典句子

雅思大作文这样举例论证 才能让考官眼前一亮

雅思大作文这样举例论证 才能让考官眼前一亮

...眼前一亮呢?小编分析给你听→_→一、什么时候使用举例子当文中的观点是名词复数,抽象或概括性的名词时,可以用具体的实例(代表性、典型性)来进一步支持或是说明观点。二、举例子时易犯的错误1)例子不具备代表性...

2023-09-30 #经典句子

什么是排比句 举两个例子

什么是排比句 举两个例子

排比句指的是将三个或三个以上结构和长度类似、语气差不多、意义相关或相同的句子排列起来。排比句的作用:排比句的运用可以加强气势,还可以使文章的节奏感加强,条理性更好,更能够表达出强烈的感情举例:1. 爱是...

2022-11-11 #经典句子

长沙事业单位申论技巧:揭开文章论证神秘的“面纱”

长沙事业单位申论技巧:揭开文章论证神秘的“面纱”

...论证”的过程是同学们的难点,因为这个部分需要通过举例子或讲道理去进行支撑,而这就是大家所欠缺的,所以平时可通过观看时事政治、新闻热点等积累写作素材,只要有素材,作文论证就会很简单。如:一、例证法(举例...

2023-06-14 #经典句子

中考语文:说明方法要辨清 结合语境谈作用

中考语文:说明方法要辨清 结合语境谈作用

...的考查。常见的说明方法,必须掌握。特别强调的,如举例子、列数字、分类别、作比较、打比方、下定义、引用等。题型方面,值得注意的是,过去多考查单一的说明方法及其作用,近些年来有些地方又开始出现了考查一个语...

2022-12-04 #经典句子

12 英语语法·句型篇——五种基本句型

12 英语语法·句型篇——五种基本句型

...。主语+谓语,即构成一个最简单的句子。举一些简单的例子:I dance.She died.we agree.……二、句型2——主语+谓语+宾语句型2在句型1的基础上多了一个宾语,宾语是什么呢?还是从句子表达事情的角度看,可以理解为“谁,对谁怎...

2023-06-15 #经典句子

考研英语89分 是因为我掌握了这个技巧

考研英语89分 是因为我掌握了这个技巧

...间的方法。但是仍然有问题。背模板容易得模板分。举个例子,作文开头我们最喜欢用with the development of economy and technology,一旦阅卷老师看到这种开头,分数档就会往下调一档或者两档。抛开范文和模板,优雅地备考考研作文...

2023-01-27 #经典句子

复杂文段不要慌 句间关系来帮忙!

复杂文段不要慌 句间关系来帮忙!

...为止,第二句更为重要一些;第三句话“比如”就是在举例子,举出在气候变化的问题上,需要有科学专家去挖掘原因,举例子为了去说明第二句话,因此目前为止还是第二句话更为重要;第四句话,“但与此同时”,形式上出现...

2023-02-01 #经典句子

搞定英语语法你先要了解这4类词

搞定英语语法你先要了解这4类词

...说,它可以做定语,状语和补语。这里不想举太过复杂的例子,The man in red is reading books.The man in the house is reading books.The man is reading books for fun.这里分别用介词短语做了定语,状语和补语。而三大从句,分别充当名词,定语和状...

2017-11-22 #经典句子