transformer

package module
v0.0.0-...-362ad42
Published: Jul 5, 2024 License: Apache-2.0 Imports: 3 Imported by: 0

README

Transformer

Overview

transformer is a pure Go package for training, testing, and running inference with Natural Language Processing (NLP) models in Go.

This package is under active development and many changes are ahead, so use it at your own risk. The package will be considered stable when version 1.0 is released.

transformer is heavily inspired by and based on the popular Python HuggingFace Transformers library. It is also influenced by its Rust counterpart, rust-bert. In fact, all pre-trained models for rust-bert can be imported into this Go package, because rust-bert's Pytorch Rust binding, tch-rs, and the Go binding gotch are built on similar principles.

transformer is part of an ambitious goal (together with tokenizer and gotch) to bring more AI/deep-learning tools to Gophers, so that they can stick to the language they love and are good at, and build faster software in production.

Dependencies

The two main dependencies are:

  • tokenizer
  • gotch

Prerequisites and installation

  • As this package depends on gotch, a Pytorch C++ API binding for Go, a pre-compiled Libtorch copy (CPU or GPU) must be installed on your machine. Please see the gotch installation instructions for details.
  • Install package: go get -u github.com/sugarme/transformer
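A typical setup might look like the following. The environment variable names below are an assumption for illustration; the gotch installation instructions define the names your setup actually requires.

```shell
# Point the build at a pre-compiled Libtorch copy. The variable names here
# (LIBTORCH, LD_LIBRARY_PATH) are assumptions -- follow the gotch installation
# instructions for the exact names and paths your platform needs.
export LIBTORCH=/opt/libtorch
export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH

# Then fetch the package.
go get -u github.com/sugarme/transformer
```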

Basic example

    import (
        "fmt"
        "log"

        "github.com/sugarme/gotch"
        ts "github.com/sugarme/gotch/tensor"
        "github.com/sugarme/tokenizer"

        "github.com/sugarme/transformer/bert"
    )

    func main() {
        var config *bert.BertConfig = new(bert.BertConfig)
        if err := transformer.LoadConfig(config, "bert-base-uncased", nil); err != nil {
            log.Fatal(err)
        }

        var model *bert.BertForMaskedLM = new(bert.BertForMaskedLM)
        if err := transformer.LoadModel(model, "bert-base-uncased", config, nil, gotch.CPU); err != nil {
            log.Fatal(err)
        }

        var tk *bert.Tokenizer = bert.NewTokenizer()
        if err := tk.Load("bert-base-uncased", nil); err != nil{
            log.Fatal(err)
        }

        sentence1 := "Looks like one [MASK] is missing"
        sentence2 := "It was a very nice and [MASK] day"

        var input []tokenizer.EncodeInput
        input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence1)))
        input = append(input, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(sentence2)))

        encodings, err := tk.EncodeBatch(input, true)
        if err != nil {
            log.Fatal(err)
        }

        var maxLen int = 0
        for _, en := range encodings {
            if len(en.Ids) > maxLen {
                maxLen = len(en.Ids)
            }
        }

        var tensors []ts.Tensor
        for _, en := range encodings {
            var tokInput []int64 = make([]int64, maxLen)
            for i := 0; i < len(en.Ids); i++ {
                tokInput[i] = int64(en.Ids[i])
            }

            tensors = append(tensors, ts.TensorFrom(tokInput))
        }

        inputTensor := ts.MustStack(tensors, 0).MustTo(gotch.CPU, true)
        var output ts.Tensor
        ts.NoGrad(func() {
            output, _, _ = model.ForwardT(inputTensor, ts.None, ts.None, ts.None, ts.None, ts.None, ts.None, false)
        })
        // Indices 4 and 7 are the positions of the [MASK] token in the
        // first and second encoded sentences, respectively.
        index1 := output.MustGet(0).MustGet(4).MustArgmax(0, false, false).Int64Values()[0]
        index2 := output.MustGet(1).MustGet(7).MustArgmax(0, false, false).Int64Values()[0]

        got1, ok := tk.IdToToken(int(index1))
        if !ok {
            fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index1)
        }
        got2, ok := tk.IdToToken(int(index2))
        if !ok {
            fmt.Printf("Cannot find a corresponding word for the given id (%v) in vocab.\n", index2)
        }

        fmt.Println(got1)
        fmt.Println(got2)
        
        // Output:
        // person
        // pleasant
    }
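The manual padding loop above (building a `tokInput` slice of length `maxLen` per sentence) can be sketched as a standalone helper. `padBatch` is a hypothetical name for illustration, not part of this package:

```go
package main

import "fmt"

// padBatch pads each sequence of token IDs with padID so that every row has
// the batch's maximum length -- the same shaping done before ts.MustStack.
func padBatch(batch [][]int64, padID int64) [][]int64 {
	maxLen := 0
	for _, seq := range batch {
		if len(seq) > maxLen {
			maxLen = len(seq)
		}
	}
	out := make([][]int64, len(batch))
	for i, seq := range batch {
		row := make([]int64, maxLen)
		for j := range row {
			row[j] = padID
		}
		copy(row, seq) // original IDs first, padding fills the tail
		out[i] = row
	}
	return out
}

func main() {
	batch := [][]int64{{101, 2003, 102}, {101, 102}}
	fmt.Println(padBatch(batch, 0)) // [[101 2003 102] [101 102 0]]
}
```

For BERT-style models the pad ID is typically 0; a production version would also build an attention mask marking the padded positions.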

Getting Started

License

transformer is Apache 2.0 licensed.

Acknowledgement

Documentation

Index

Examples

Constants

This section is empty.

Variables

This section is empty.

Functions

func LoadConfig

func LoadConfig(config pretrained.Config, modelNameOrPath string, customParams map[string]interface{}) error

LoadConfig loads pretrained configuration data from a local or remote file.

Parameters:

  • `config`: any model config that implements the pretrained `Config` interface
  • `modelNameOrPath`: a string of either

  • Model name or
  • File name or path or
  • URL to remote file

If `modelNameOrPath` is resolved, the function will cache data in the directory given by the `TransformerCache` environment variable if set, otherwise in the `$HOME/.cache/transformers/` directory. If `modelNameOrPath` is a valid URL, the file will be downloaded and cached. Finally, the configuration data is loaded into the `config` parameter.

Example
package main

import (
	"fmt"
	"log"

	"github.com/yinziyang/transformer"
	"github.com/yinziyang/transformer/bert"
)

func main() {
	modelNameOrPath := "bert-base-uncased"
	var config bert.BertConfig
	err := transformer.LoadConfig(&config, modelNameOrPath, nil)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(config.VocabSize)

}
Output:

30522

func LoadModel

func LoadModel(model pretrained.Model, modelNameOrPath string, config pretrained.Config, customParams map[string]interface{}, device gotch.Device) error

LoadModel loads pretrained model data from a local or remote file.

Parameters:

  • `model`: any model type that implements the pretrained `Model` interface
  • `modelNameOrPath`: a string of either

  • Model name or
  • File name or path or
  • URL to remote file

If `modelNameOrPath` is resolved, the function will cache data in the directory given by the `TransformerCache` environment variable if set, otherwise in the `$HOME/.cache/transformers/` directory. If `modelNameOrPath` is a valid URL, the file will be downloaded and cached. Finally, the model weights are loaded into the model's `varstore`.

func LoadTokenizer

func LoadTokenizer(tk pretrained.Tokenizer, modelNameOrPath string, customParams map[string]interface{}) error

LoadTokenizer loads a pretrained tokenizer from a local or remote file.

Parameters:

  • `tk`: any tokenizer model that implements the pretrained `Tokenizer` interface
  • `modelNameOrPath`: a string of either

  • Model name or
  • File name or path or
  • URL to remote file

If `modelNameOrPath` is resolved, the function will cache data in the directory given by the `TransformerCache` environment variable if set, otherwise in the `$HOME/.cache/transformers/` directory. If `modelNameOrPath` is a valid URL, the file will be downloaded and cached. Finally, the vocab data is loaded into `tk`.

Types

This section is empty.

Directories

  • example
  • ner
  • qa
