bert

package
v0.7.0
Warning

This package is not in the latest version of its module.

Published: May 24, 2021 License: BSD-2-Clause Imports: 46 Imported by: 0

Documentation

Overview

Package bert provides an implementation of the BERT model (Bidirectional Encoder Representations from Transformers).

Reference: "Attention Is All You Need" by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin (2017), which introduced the Transformer architecture that BERT builds on (http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)


Constants

const (
	// DefaultConfigurationFile is the default BERT JSON configuration filename.
	DefaultConfigurationFile = "config.json"
	// DefaultVocabularyFile is the default BERT model's vocabulary filename.
	DefaultVocabularyFile = "vocab.txt"
	// DefaultModelFile is the default BERT spaGO model filename.
	DefaultModelFile = "spago_model.bin"
	// DefaultEmbeddingsStorage is the default directory name for BERT model's embedding storage.
	DefaultEmbeddingsStorage = "embeddings_storage"
)
const DefaultFakeLabel = "FAKE"

DefaultFakeLabel is the default value for the fake label used for BERT "discriminate" server requests.

const DefaultPredictedLabel = "PREDICTED"

DefaultPredictedLabel is the default value for the predicted label used for BERT "predict" server requests.

const DefaultRealLabel = "REAL"

DefaultRealLabel is the default value for the real label used for BERT "discriminate" server requests.

Variables

This section is empty.

Functions

func ConvertHuggingFacePreTrained

func ConvertHuggingFacePreTrained(modelPath string) error

ConvertHuggingFacePreTrained converts a HuggingFace pre-trained BERT transformer model to a corresponding spaGO model.

func Dump

func Dump(value interface{}, pretty bool) ([]byte, error)

Dump serializes the given value to JSON.
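The effect of the pretty flag is not documented here. A minimal sketch of a Dump-like function, assuming pretty selects indented JSON (via encoding/json's MarshalIndent) and otherwise compact JSON; dump and token are local stand-ins, not the package's own code, and the "LOC" label is illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// dump is a hypothetical stand-in for bert.Dump. It assumes pretty=true
// means indented JSON and pretty=false means compact JSON.
func dump(value interface{}, pretty bool) ([]byte, error) {
	if pretty {
		return json.MarshalIndent(value, "", "  ")
	}
	return json.Marshal(value)
}

// token mirrors the documented Token type and its JSON tags.
type token struct {
	Text  string `json:"text"`
	Start int    `json:"start"`
	End   int    `json:"end"`
	Label string `json:"label"`
}

func main() {
	out, err := dump(token{Text: "Rome", Start: 0, End: 4, Label: "LOC"}, false)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"text":"Rome","start":0,"end":4,"label":"LOC"}
}
```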

Types

type Answer

type Answer struct {
	Text       string    `json:"text"`
	Start      int       `json:"start"`
	End        int       `json:"end"`
	Confidence mat.Float `json:"confidence"`
}

Answer represents a single JSON-serializable BERT question-answering answer, used as part of a server's response.

type Answers added in v0.6.0

type Answers []Answer

Answers is a slice of Answer elements that implements sort.Interface.

func (Answers) Len added in v0.6.0

func (p Answers) Len() int

Len returns the length of the slice.

func (Answers) Less added in v0.6.0

func (p Answers) Less(i, j int) bool

Less returns true if the Answer.Confidence of the element at position i is lower than that of the element at position j.

func (Answers) Sort added in v0.6.0

func (p Answers) Sort()

Sort sorts the elements of Answers by Answer.Confidence.

func (Answers) Swap added in v0.6.0

func (p Answers) Swap(i, j int)

Swap swaps the elements at positions i and j.

type Body

type Body struct {
	Text            string                                `json:"text"`
	Text2           string                                `json:"text2"`
	PoolingStrategy grpcapi.EncodeRequest_PoolingStrategy `json:"pooling_strategy"`
}

Body is the JSON-serializable expected request body for various BERT server requests.

type ClassConfidencePair

type ClassConfidencePair struct {
	Class      string    `json:"class"`
	Confidence mat.Float `json:"confidence"`
}

ClassConfidencePair associates a Confidence with a symbolic Class.

type Classifier

type Classifier struct {
	Config ClassifierConfig
	*linear.Model
}

Classifier implements a BERT Classifier.

func NewTokenClassifier

func NewTokenClassifier(config ClassifierConfig) *Classifier

NewTokenClassifier returns a new BERT Classifier model.

type ClassifierConfig

type ClassifierConfig struct {
	InputSize int
	Labels    []string
}

ClassifierConfig provides configuration settings for a BERT Classifier.

type ClassifyResponse

type ClassifyResponse struct {
	Class        string                `json:"class"`
	Confidence   mat.Float             `json:"confidence"`
	Distribution []ClassConfidencePair `json:"distribution"`
	// Took is the number of milliseconds it took the server to execute the request.
	Took int64 `json:"took"`
}

ClassifyResponse is a JSON-serializable server response for BERT "classify" requests.

type Config

type Config struct {
	HiddenAct             string            `json:"hidden_act"`
	HiddenSize            int               `json:"hidden_size"`
	IntermediateSize      int               `json:"intermediate_size"`
	MaxPositionEmbeddings int               `json:"max_position_embeddings"`
	NumAttentionHeads     int               `json:"num_attention_heads"`
	NumHiddenLayers       int               `json:"num_hidden_layers"`
	TypeVocabSize         int               `json:"type_vocab_size"`
	VocabSize             int               `json:"vocab_size"`
	ID2Label              map[string]string `json:"id2label"`
	Training              bool              `json:"training"` // Custom for spaGO
}

Config provides configuration settings for a BERT Model.

func LoadConfig

func LoadConfig(file string) (Config, error)

LoadConfig loads a BERT model Config from file.

type Discriminator

type Discriminator struct {
	*stack.Model
}

Discriminator is a BERT Discriminator model.

func NewDiscriminator

func NewDiscriminator(config DiscriminatorConfig) *Discriminator

NewDiscriminator returns a new BERT Discriminator model.

func (*Discriminator) Discriminate added in v0.2.0

func (m *Discriminator) Discriminate(encoded []ag.Node) []int

Discriminate returns 0 or 1 for each encoded element, where 1 means that the word is out of context.

type DiscriminatorConfig

type DiscriminatorConfig struct {
	InputSize        int
	HiddenSize       int
	HiddenActivation ag.OpName
	OutputActivation ag.OpName
}

DiscriminatorConfig provides configuration settings for a BERT Discriminator.

type Embeddings

type Embeddings struct {
	nn.BaseModel
	EmbeddingsConfig
	Words            *embeddings.Model
	Position         []nn.Param `spago:"type:weights"` // TODO: stop auto-wrapping
	TokenType        []nn.Param `spago:"type:weights"`
	Norm             *layernorm.Model
	Projector        *linear.Model
	UnknownEmbedding ag.Node `spago:"scope:processor"`
}

Embeddings is a BERT Embeddings model.

func NewEmbeddings

func NewEmbeddings(config EmbeddingsConfig) *Embeddings

NewEmbeddings returns a new BERT Embeddings model.

func (*Embeddings) Encode added in v0.2.0

func (m *Embeddings) Encode(words []string) []ag.Node

Encode transforms a string sequence into an encoded representation.

func (*Embeddings) InitProcessor added in v0.2.0

func (m *Embeddings) InitProcessor()

InitProcessor initializes the unknown embeddings.

type EmbeddingsConfig

type EmbeddingsConfig struct {
	Size                int
	OutputSize          int
	MaxPositions        int
	TokenTypes          int
	WordsMapFilename    string
	WordsMapReadOnly    bool
	DeletePreEmbeddings bool
}

EmbeddingsConfig provides configuration settings for BERT Embeddings.

type EncodeResponse

type EncodeResponse struct {
	Data []mat.Float `json:"data"`
	// Took is the number of milliseconds it took the server to execute the request.
	Took int64 `json:"took"`
}

EncodeResponse is a JSON-serializable server response for BERT "encode" requests.

type Encoder

type Encoder struct {
	EncoderConfig
	*stack.Model
}

Encoder is a BERT Encoder model.

func NewAlbertEncoder

func NewAlbertEncoder(config EncoderConfig) *Encoder

NewAlbertEncoder returns a new variant of the BERT encoder model in which the stack of N identical BERT encoder layers shares the same parameters, as in ALBERT.

func NewBertEncoder

func NewBertEncoder(config EncoderConfig) *Encoder

NewBertEncoder returns a new BERT encoder model composed of a stack of N identical BERT encoder layers.
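The only documented difference between the two constructors is parameter sharing. The sketch below illustrates that difference with a hypothetical layerParams type: the BERT-style stack allocates one parameter set per layer, while the ALBERT-style stack reuses a single set for every layer. It is not spaGO's construction code:

```go
package main

import "fmt"

// layerParams is a hypothetical stand-in for one encoder layer's weights.
type layerParams struct {
	weights []float64
}

// newStack builds numLayers layer slots. With share=true every slot points at
// the same parameters (ALBERT-style); otherwise each slot gets its own set (BERT-style).
func newStack(numLayers, size int, share bool) []*layerParams {
	layers := make([]*layerParams, numLayers)
	var shared *layerParams
	if share {
		shared = &layerParams{weights: make([]float64, size)}
	}
	for i := range layers {
		if share {
			layers[i] = shared
		} else {
			layers[i] = &layerParams{weights: make([]float64, size)}
		}
	}
	return layers
}

// uniqueParams counts the distinct parameter sets in a stack.
func uniqueParams(layers []*layerParams) int {
	seen := map[*layerParams]bool{}
	for _, l := range layers {
		seen[l] = true
	}
	return len(seen)
}

func main() {
	fmt.Println(uniqueParams(newStack(12, 4, false))) // 12
	fmt.Println(uniqueParams(newStack(12, 4, true)))  // 1
}
```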

type EncoderConfig

type EncoderConfig struct {
	Size                   int
	NumOfAttentionHeads    int
	IntermediateSize       int
	IntermediateActivation ag.OpName
	NumOfLayers            int
}

EncoderConfig provides configuration parameters for a BERT Encoder. TODO: include and use the dropout hyperparameter.

type EncoderLayer

type EncoderLayer struct {
	nn.BaseModel
	MultiHeadAttention *multiheadattention.Model
	NormAttention      *layernorm.Model
	FFN                *stack.Model
	NormFFN            *layernorm.Model
	Index              int // layer index (useful for debugging)
}

EncoderLayer is a BERT Encoder Layer model.

func (*EncoderLayer) Forward added in v0.2.0

func (m *EncoderLayer) Forward(xs ...ag.Node) []ag.Node

Forward performs the forward step for each input node and returns the result.

type LabelerOptionsType

type LabelerOptionsType struct {
	MergeEntities     bool `json:"mergeEntities"`     // default false
	FilterNotEntities bool `json:"filterNotEntities"` // default false
}

LabelerOptionsType is a JSON-serializable set of options for BERT "tag" (labeler) requests.

type Model

type Model struct {
	nn.BaseModel
	Config          Config
	Vocabulary      *vocabulary.Vocabulary
	Embeddings      *Embeddings
	Encoder         *Encoder
	Predictor       *Predictor
	Discriminator   *Discriminator // used by "ELECTRA" training method
	Pooler          *Pooler
	SeqRelationship *linear.Model
	SpanClassifier  *SpanClassifier
	Classifier      *Classifier
}

Model implements a BERT model.

func LoadModel

func LoadModel(modelPath string) (*Model, error)

LoadModel loads a BERT Model from file.

func NewDefaultBERT

func NewDefaultBERT(config Config, embeddingsStoragePath string) *Model

NewDefaultBERT returns a new model based on the original BERT architecture.

func (*Model) Answer added in v0.6.0

func (m *Model) Answer(question string, passage string) Answers

Answer returns a slice of candidate answers for the given question-passage pair. The answers are sorted by confidence level in descending order.

func (*Model) Discriminate added in v0.2.0

func (m *Model) Discriminate(encoded []ag.Node) []int

Discriminate returns 0 or 1 for each encoded element, where 1 means that the word is out of context.

func (*Model) Encode added in v0.2.0

func (m *Model) Encode(tokens []string) []ag.Node

Encode transforms a string sequence into an encoded representation.

func (*Model) Pool added in v0.2.0

func (m *Model) Pool(transformed []ag.Node) ag.Node

Pool "pools" the model by simply taking the hidden state corresponding to the `[CLS]` token.

func (*Model) PredictMLM added in v0.6.0

func (m *Model) PredictMLM(text string) []Token

PredictMLM performs the Masked-Language-Model (MLM) prediction. It returns the best guess for the masked (i.e. `[MASK]`) tokens in the input text.

func (*Model) PredictMasked added in v0.2.0

func (m *Model) PredictMasked(transformed []ag.Node, masked []int) map[int]ag.Node

PredictMasked performs a masked prediction task. It returns the predictions for the indices associated with the masked nodes.
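The masked argument is assumed to hold the positions of the `[MASK]` tokens in the transformed sequence, and the returned map is assumed to be keyed by those same positions. A minimal sketch of collecting such positions from a tokenized input; maskedIndices is a hypothetical helper:

```go
package main

import "fmt"

// maskedIndices returns the positions of "[MASK]" tokens in a tokenized sequence.
func maskedIndices(tokens []string) []int {
	var idx []int
	for i, t := range tokens {
		if t == "[MASK]" {
			idx = append(idx, i)
		}
	}
	return idx
}

func main() {
	tokens := []string{"[CLS]", "the", "[MASK]", "sat", "on", "the", "[MASK]", "[SEP]"}
	fmt.Println(maskedIndices(tokens)) // [2 6]
}
```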

func (*Model) PredictSeqRelationship added in v0.2.0

func (m *Model) PredictSeqRelationship(pooled ag.Node) ag.Node

PredictSeqRelationship predicts whether the second sentence in the pair is the subsequent sentence in the original document.

func (*Model) SequenceClassification added in v0.2.0

func (m *Model) SequenceClassification(transformed []ag.Node) ag.Node

SequenceClassification performs a single sentence-level classification, using the pooled CLS token.

func (*Model) TokenClassification added in v0.2.0

func (m *Model) TokenClassification(transformed []ag.Node) []ag.Node

TokenClassification performs a classification for each element in the sequence.

func (*Model) Vectorize added in v0.6.0

func (m *Model) Vectorize(text string, poolingStrategy PoolingStrategy) (mat.Matrix, error)

Vectorize transforms the text into a dense vector representation.

type Pooler

type Pooler struct {
	*stack.Model
}

Pooler is a BERT Pooler model.

func NewPooler

func NewPooler(config PoolerConfig) *Pooler

NewPooler returns a new BERT Pooler model.

type PoolerConfig

type PoolerConfig struct {
	InputSize  int
	OutputSize int
}

PoolerConfig provides configuration settings for a BERT Pooler.

type PoolingStrategy added in v0.6.0

type PoolingStrategy int

PoolingStrategy defines the method used to obtain the dense sentence representation.

const (
	// ClsToken gets the encoding state corresponding to [CLS], i.e. the first token (default)
	ClsToken PoolingStrategy = iota
	// ReduceMean takes the average of the encoding states
	ReduceMean
	// ReduceMax takes the maximum of the encoding states
	ReduceMax
	// ReduceMeanMax computes ReduceMean and ReduceMax separately and then concatenates the results
	ReduceMeanMax
)
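A minimal sketch of the four pooling strategies over a sequence of encoding states (one vector per token), using lower-case local copies of the constants. Whether Vectorize includes the `[CLS]` and `[SEP]` states in the mean and max is not documented here, so this sketch simply reduces over every state it is given:

```go
package main

import "fmt"

type poolingStrategy int

const (
	clsToken poolingStrategy = iota
	reduceMean
	reduceMax
	reduceMeanMax
)

// pool reduces a sequence of encoding states to a single vector.
func pool(states [][]float64, s poolingStrategy) []float64 {
	switch s {
	case clsToken:
		// The first state corresponds to [CLS].
		return append([]float64(nil), states[0]...)
	case reduceMean:
		out := make([]float64, len(states[0]))
		for _, st := range states {
			for j, v := range st {
				out[j] += v
			}
		}
		for j := range out {
			out[j] /= float64(len(states))
		}
		return out
	case reduceMax:
		out := append([]float64(nil), states[0]...)
		for _, st := range states[1:] {
			for j, v := range st {
				if v > out[j] {
					out[j] = v
				}
			}
		}
		return out
	case reduceMeanMax:
		// Concatenate the mean vector and the max vector.
		return append(pool(states, reduceMean), pool(states, reduceMax)...)
	}
	panic(fmt.Sprintf("unknown pooling strategy %d", s))
}

func main() {
	states := [][]float64{{1, 4}, {3, 0}, {2, 2}}
	fmt.Println(pool(states, clsToken))      // [1 4]
	fmt.Println(pool(states, reduceMean))    // [2 2]
	fmt.Println(pool(states, reduceMax))     // [3 4]
	fmt.Println(pool(states, reduceMeanMax)) // [2 2 3 4]
}
```

Note that ReduceMeanMax doubles the output width, so downstream consumers of the vector need to expect 2 × hidden size.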

type Predictor

type Predictor struct {
	*stack.Model
}

Predictor is a BERT Predictor model.

func NewPredictor

func NewPredictor(config PredictorConfig) *Predictor

NewPredictor returns a new BERT Predictor model.

func (*Predictor) PredictMasked added in v0.2.0

func (m *Predictor) PredictMasked(encoded []ag.Node, masked []int) map[int]ag.Node

PredictMasked performs a masked prediction task. It returns the predictions for the indices associated with the masked nodes.

type PredictorConfig

type PredictorConfig struct {
	InputSize        int
	HiddenSize       int
	OutputSize       int
	HiddenActivation ag.OpName
	OutputActivation ag.OpName
}

PredictorConfig provides configuration settings for a BERT Predictor.

type QABody

type QABody struct {
	Question string `json:"question"`
	Passage  string `json:"passage"`
}

QABody is the JSON-serializable expected request body for BERT question-answering server requests.

type QuestionAnsweringResponse

type QuestionAnsweringResponse struct {
	Answers Answers `json:"answers"`
	// Took is the number of milliseconds it took the server to execute the request.
	Took int64 `json:"took"`
}

QuestionAnsweringResponse is the JSON-serializable structure for BERT question-answering server response.

type Response

type Response struct {
	Tokens []Token `json:"tokens"`
	// Took is the number of milliseconds it took the server to execute the request.
	Took int64 `json:"took"`
}

Response is the JSON-serializable server response for various BERT-related requests.

type Server

type Server struct {
	TimeoutSeconds  int
	MaxRequestBytes int

	// UnimplementedBERTServer must be embedded to have forward compatible implementations for gRPC.
	grpcapi.UnimplementedBERTServer
	// contains filtered or unexported fields
}

Server contains everything needed to run a BERT server.

func NewServer

func NewServer(model *Model) *Server

NewServer returns a new Server for the given model.

func (*Server) Answer

Answer handles a question-answering request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.

func (*Server) Classify

Classify handles a classification request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.

func (*Server) ClassifyHandler

func (s *Server) ClassifyHandler(w http.ResponseWriter, req *http.Request)

ClassifyHandler handles a classify request over HTTP.

func (*Server) Discriminate

Discriminate handles a discriminate request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.

func (*Server) DiscriminateHandler

func (s *Server) DiscriminateHandler(w http.ResponseWriter, req *http.Request)

DiscriminateHandler handles a discriminate request over HTTP.

func (*Server) Encode

Encode handles an encoding request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.

func (*Server) LabelerHandler

func (s *Server) LabelerHandler(w http.ResponseWriter, req *http.Request)

LabelerHandler handles a labeling request over HTTP.

func (*Server) Predict

Predict handles a predict request over gRPC. TODO(evanmcclure@gmail.com) Reuse the gRPC message type for HTTP requests.

func (*Server) PredictHandler

func (s *Server) PredictHandler(w http.ResponseWriter, req *http.Request)

PredictHandler handles a predict request over HTTP.

func (*Server) QaHandler

func (s *Server) QaHandler(w http.ResponseWriter, req *http.Request)

QaHandler is the HTTP server handler function for BERT question-answering requests.

func (*Server) SentenceEncoderHandler

func (s *Server) SentenceEncoderHandler(w http.ResponseWriter, req *http.Request)

SentenceEncoderHandler handles a sentence encoding request over HTTP.

func (*Server) StartDefaultServer

func (s *Server) StartDefaultServer(address, grpcAddress, tlsCert, tlsKey string, tlsDisable bool)

StartDefaultServer starts a basic BERT HTTP server. For more control over the HTTP server, run your own HTTP router using the public handler functions.

type SpanClassifier

type SpanClassifier struct {
	*linear.Model
}

SpanClassifier implements span classification for extractive question-answering tasks such as SQuAD. It uses a linear layer to compute the "span start logits" and "span end logits".

func NewSpanClassifier

func NewSpanClassifier(config SpanClassifierConfig) *SpanClassifier

NewSpanClassifier returns a new BERT SpanClassifier model.

func (*SpanClassifier) Classify added in v0.2.0

func (p *SpanClassifier) Classify(xs []ag.Node) (startLogits, endLogits []ag.Node)

Classify returns the "span start logits" and "span end logits".
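Turning start and end logits into an answer span is commonly done by choosing the pair with start <= end, within a maximum length, that maximizes the sum of the two logits. The sketch below shows that common SQuAD-style decoding; it is not necessarily how Model.Answer selects or scores its candidates, and bestSpan and maxLen are hypothetical:

```go
package main

import (
	"fmt"
	"math"
)

// bestSpan returns the (start, end) pair, start <= end, with the highest
// startLogits[start] + endLogits[end], limited to spans of at most maxLen tokens.
func bestSpan(startLogits, endLogits []float64, maxLen int) (int, int, float64) {
	bestStart, bestEnd, bestScore := -1, -1, math.Inf(-1)
	for i, s := range startLogits {
		for j := i; j < len(endLogits) && j < i+maxLen; j++ {
			if score := s + endLogits[j]; score > bestScore {
				bestStart, bestEnd, bestScore = i, j, score
			}
		}
	}
	return bestStart, bestEnd, bestScore
}

func main() {
	start := []float64{0.1, 2.0, 0.3, 1.5}
	end := []float64{0.0, 0.2, 3.0, 0.4}
	i, j, score := bestSpan(start, end, 30)
	fmt.Println(i, j, score) // 1 2 5
}
```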

type SpanClassifierConfig

type SpanClassifierConfig struct {
	InputSize int
}

SpanClassifierConfig provides configuration settings for a BERT SpanClassifier.

type Token

type Token struct {
	Text  string `json:"text"`
	Start int    `json:"start"`
	End   int    `json:"end"`
	Label string `json:"label"`
}

Token is a JSON-serializable labeled text token.

type TokenClassifierBody

type TokenClassifierBody struct {
	Options LabelerOptionsType `json:"options"`
	Text    string             `json:"text"`
}

TokenClassifierBody provides JSON-serializable parameters for BERT "tag" (labeler) requests.

type TokenSlice added in v0.5.0

type TokenSlice []Token

TokenSlice is a slice of Token elements that implements sort.Interface.

func (TokenSlice) Len added in v0.5.0

func (p TokenSlice) Len() int

Len returns the length of the slice.

func (TokenSlice) Less added in v0.5.0

func (p TokenSlice) Less(i, j int) bool

Less returns true if the Token.Start of the element at position i is lower than that of the element at position j.

func (TokenSlice) Sort added in v0.5.0

func (p TokenSlice) Sort()

Sort sorts the TokenSlice's elements by Token.Start.

func (TokenSlice) Swap added in v0.5.0

func (p TokenSlice) Swap(i, j int)

Swap swaps the elements at positions i and j.

type Trainer

type Trainer struct {
	TrainingConfig
	// contains filtered or unexported fields
}

Trainer implements the training process for a BERT Model.

func NewTrainer

func NewTrainer(model *Model, config TrainingConfig) *Trainer

NewTrainer returns a new BERT Trainer.

func (*Trainer) Train

func (t *Trainer) Train()

Train executes the training process.

type TrainingConfig

type TrainingConfig struct {
	Seed             uint64
	BatchSize        int
	GradientClipping mat.Float
	UpdateMethod     gd.MethodConfig
	CorpusPath       string
	ModelPath        string
}

TrainingConfig provides configuration settings for a BERT Trainer.
