couplet

command
Version: v1.0.0
Published: Jun 15, 2023 License: MIT Imports: 10 Imported by: 0

README

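Couplet matching

This example uses a GPT model to complete Chinese couplets automatically: given the first line, it generates a matching second line. The model is trained on the open-source couplet-dataset. Sample results: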

$ go run main.go evaluate --model model7M 晚风摇树树还挺
load embedding...
model loaded
inputs: [472 3 462 148 148 342 1516]
map[4.278747:[醉] 5.084207:[润] 8.868446:[晨]]
map[3.8447263:[花] 4.750472:[润] 8.635651:[露]]
map[5.46043:[花] 6.7003703:[露] 10.768249:[润]]
map[4.3850584:[露] 4.875666:[润] 9.896332:[花]]
map[3.6241615:[红] 5.611262:[润] 10.782802:[花]]
map[4.3855276:[花] 5.48069:[红] 9.480111:[更]]
map[3.7904112:[心] 4.269902:[花] 10.3220415:[红]]
晨露润花花更红

$ go run main.go evaluate --model model7M 投石向天跟命斗
load embedding...
model loaded
inputs: [1233 190 383 11 2623 620 490]
map[5.7068815:[门] 5.7826476:[问] 9.79136:[闭]]
map[3.0136497:[问] 3.1092193:[人] 8.903796:[门]]
map[3.021591:[还] 3.448888:[歌] 8.96453:[问]]
map[4.9368696:[地] 5.7390223:[时] 9.438878:[卷]]
map[3.5542138:[话] 3.858942:[时] 8.253393:[与]]
map[3.025545:[与] 3.2461479:[卷] 9.06726:[时]]
map[4.250452:[时] 4.712057:[舟] 10.401218:[争]]
闭门问卷与时争

$ go run main.go evaluate --model model7M 我是谁
load embedding...
model loaded
inputs: [85 62 191]
map[4.3809786:[雨] 4.9436274:[染] 7.105626:[绿]]
map[3.8163047:[水] 4.013789:[东] 4.088595:[得]]
map[4.872726:[唱] 5.4107614:[兰] 6.3983927:[发]]
绿得发

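The model has about 7.51 million parameters in total and a vocabulary of 4,436 characters (only the first 10,000 samples were used for training).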

+------------------------+---------+
|          NAME          |  COUNT  |
+------------------------+---------+
| transformer0_attention |    1872 |
| transformer0_dense     | 1256640 |
| transformer0_output    | 1254960 |
| transformer1_attention |    1872 |
| transformer1_dense     | 1256640 |
| transformer1_output    | 1254960 |
| output                 | 2488596 |
| total                  | 7515540 |
+------------------------+---------+
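
The totals in the table can be checked with a little arithmetic. The sketch below is one decomposition that happens to reproduce the listed counts; the layer shapes are an assumption inferred from the numbers, not something the README states. It assumes each transformer block flattens the 70 x 8 input to 560 values, expands to a hidden width of 4 x 560, and projects back, with biases, and that the output layer maps 560 values to the 4,436-character vocabulary. The attention count (1872) is copied from the table as-is rather than derived.

```go
package main

import "fmt"

const (
	embeddingDim    = 8    // from the README's parameter configuration
	paddingSize     = 70   // from the README's parameter configuration
	transformerSize = 2    // from the README's parameter configuration
	vocabSize       = 4436 // vocabulary size stated in the README
	attentionParams = 1872 // copied from the table, not derived here
)

// denseParams returns the weight and bias count of a fully connected layer.
func denseParams(in, out int) int {
	return in*out + out
}

// totalParams rebuilds the table total under the assumed layer shapes.
func totalParams() int {
	flat := paddingSize * embeddingDim // 560 values per sample
	hidden := 4 * flat                 // assumption: feed-forward width is 4x the input
	perBlock := attentionParams + denseParams(flat, hidden) + denseParams(hidden, flat)
	return transformerSize*perBlock + denseParams(flat, vocabSize)
}

func main() {
	flat := paddingSize * embeddingDim
	fmt.Println("dense: ", denseParams(flat, 4*flat))    // 1256640
	fmt.Println("output:", denseParams(4*flat, flat))    // 1254960
	fmt.Println("logits:", denseParams(flat, vocabSize)) // 2488596
	fmt.Println("total: ", totalParams())                // 7515540
}
```

Under these assumptions most of the parameters sit in the feed-forward and output layers, which scale with paddingSize x embeddingDim, so changing either constant changes the model size substantially.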

train 200, cost=2h15m7.877395694s, loss=3.665343e-02

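Model parameter configuration

const embeddingDim = 8 // each character vector is 8 float32 values
const paddingSize = 70 // the longest sample is 34*2 characters, so the padding length must be greater than 68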
const heads = 4
const batchSize = 128
const epoch = 200
const lr = 0.001
const transformerSize = 2
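
To illustrate why paddingSize must exceed the longest sample, here is a minimal sketch of fixed-length padding. It is not the project's code: the helper padSequence and the padding id padID = 0 are hypothetical, and how the real model lays out the two couplet lines inside the 70 slots is not described in this README.

```go
package main

import (
	"errors"
	"fmt"
)

const paddingSize = 70 // value from the parameter configuration

// padID is a hypothetical padding token id; the real value is not shown in the README.
const padID = 0

var errTooLong = errors.New("sequence does not fit in paddingSize")

// padSequence copies ids into a slice of exactly paddingSize entries and
// fills the unused tail with padID. Longer inputs are rejected.
func padSequence(ids []int) ([]int, error) {
	if len(ids) > paddingSize {
		return nil, errTooLong
	}
	out := make([]int, paddingSize)
	n := copy(out, ids)
	for i := n; i < paddingSize; i++ {
		out[i] = padID
	}
	return out, nil
}

func main() {
	// Token ids taken from the "inputs:" line of the first sample run.
	ids := []int{472, 3, 462, 148, 148, 342, 1516}
	padded, err := padSequence(ids)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(len(padded), padded[:10])
}
```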

Build

After adjusting the parameters in logic/model/params.go, build with the following command:

go build

Training

# Download the dataset
./couplet download
# Trim the dataset to speed up training
./couplet cut 10000
# Train the model
./couplet train

Inference

$ ./couplet evaluate --model ./model7M 丹枫江冷人初去
load embedding...
model loaded
inputs: [338 756 51 394 6 543 155]
map[4.8344746:[写] 7.0445685:[柳] 10.421665:[绿]]
map[3.87284:[堤] 4.2659774:[叶] 9.011678:[柳]]
map[4.659379:[发] 4.795504:[新] 9.285393:[堤]]
map[4.7904825:[柳] 5.4895434:[堤] 10.241007:[新]]
map[2.8209953:[现] 3.400415:[复] 8.911065:[燕]]
map[3.1597254:[燕] 5.2025476:[心] 10.825405:[复]]
map[3.1378353:[环] 4.011742:[红] 10.709359:[来]]
绿柳堤新燕复来

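Because the GPT model generates one character at a time, each map line in the output corresponds to one output position and lists the candidate characters for that position, keyed by score. The printed values exceed 1, so they appear to be unnormalized scores rather than probabilities. In every sample run, the character in the final line is the highest-scoring candidate at its position.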
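
The selection step can be sketched as picking the highest-scoring candidate at each position. This is an illustration rather than the project's code: the helper pickBest is hypothetical, and the type map[float32][]string is only a guess based on how the lines are printed. The scores are copied from the inference run shown in this section.

```go
package main

import (
	"fmt"
	"strings"
)

// pickBest returns the first character stored under the highest score,
// or an empty string if there are no candidates.
func pickBest(cands map[float32][]string) string {
	best := ""
	var bestScore float32
	found := false
	for score, chars := range cands {
		if len(chars) == 0 {
			continue
		}
		if !found || score > bestScore {
			bestScore, best, found = score, chars[0], true
		}
	}
	return best
}

// decode applies pickBest to every position and joins the results.
func decode(steps []map[float32][]string) string {
	var sb strings.Builder
	for _, s := range steps {
		sb.WriteString(pickBest(s))
	}
	return sb.String()
}

func main() {
	// One map per output position, copied from the printed run.
	steps := []map[float32][]string{
		{4.8344746: {"写"}, 7.0445685: {"柳"}, 10.421665: {"绿"}},
		{3.87284: {"堤"}, 4.2659774: {"叶"}, 9.011678: {"柳"}},
		{4.659379: {"发"}, 4.795504: {"新"}, 9.285393: {"堤"}},
		{4.7904825: {"柳"}, 5.4895434: {"堤"}, 10.241007: {"新"}},
		{2.8209953: {"现"}, 3.400415: {"复"}, 8.911065: {"燕"}},
		{3.1597254: {"燕"}, 5.2025476: {"心"}, 10.825405: {"复"}},
		{3.1378353: {"环"}, 4.011742: {"红"}, 10.709359: {"来"}},
	}
	fmt.Println(decode(steps)) // prints 绿柳堤新燕复来
}
```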


Directories

Path Synopsis
logic
