rwkv

package module
v0.0.8 Latest
Published: Dec 31, 2023 License: MIT Imports: 18 Imported by: 0

README

rwkv

Pure Go bindings for RWKV with cross-platform support.

rwkv.go is a wrapper around rwkv.cpp, which is an adaptation of ggml.cpp.

Installation

go get github.com/seasonjs/rwkv

AutoModel Compatibility

See the deps folder for dylib compatibility. You can also build the library yourself; pull requests are welcome.

NewRwkvAutoModel supports both AMD and NVIDIA GPUs on Windows.

NewRwkvModel requires you to load the dynamic library manually, and the dynamic library is platform dependent.

Windows AMD GPU users may need to check the ROCm architecture for more information.

Windows NVIDIA GPU users may need to check the CUDA architecture for more information.

platform  x32            x64                            arm            AMD/ROCM           NVIDIA/CUDA
windows   not supported  supported (avx/avx2/avx512)    not supported  rocm5.5 supported  cuda12 supported
linux     not supported  supported                      not supported  not supported      not supported
darwin    not supported  supported                      supported      not supported      not supported

AutoModel Dynamic Libraries Disclaimer

The Source Of Dynamic Libraries

These dynamic libraries come from the rwkv.cpp releases. The library version is recorded in the rwkv.version file. Anyone can verify a file's integrity by checking its MD5 checksum.

The Security Of Dynamic Libraries

All I can say is that the dynamic libraries are built in public and contain no intentionally malicious logic. If you are worried about their security, you can build them yourself.

Neither I nor any author associated with the dynamic libraries assumes responsibility or legal liability for problems that arise during their use.

Usage

You can find a complete example in the examples folder.

Here is a simple example:

package main

import (
	"fmt"
	"github.com/seasonjs/rwkv"
)

func main() {
	model, err := rwkv.NewRwkvAutoModel(rwkv.RwkvOptions{
		MaxTokens:     500,
		StopString:    "\n\n",
		Temperature:   0.8,
		TopP:          0.5,
		TokenizerType: rwkv.World, // or rwkv.Normal
		PrintError:    true,
		CpuThreads:    10,
		GpuEnable:     false,
	})

	if err != nil {
		fmt.Print(err.Error())
		return
	}

	defer model.Close()

	err = model.LoadFromFile("./models/RWKV-5-World-0.4B-v2-20231113-ctx4096-F16.bin")
	if err != nil {
		fmt.Print(err.Error())
		return
	}
	prompt := `The following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob.
Alice is very intelligent, creative and friendly.
Alice likes to tell Bob a lot about herself and her opinions.
Alice usually gives Bob kind, helpful and informative advices.

Bob: lhc
Alice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。
LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。

Bob: 企鹅会飞吗
Alice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。

`
	user := `Bob: 请介绍北京的旅游景点?
Alice: `

	ctx, err := model.InitState(prompt)

	if err != nil {
		fmt.Print(err.Error())
		return
	}

	out, err := ctx.Predict(user)

	if err != nil {
		fmt.Print(err.Error())
		return
	}

	fmt.Print(out)
}

Packaging

To ship a working program that includes this AI, you will need to include the following files:

  • librwkv.dylib / librwkv.so / rwkv.dll (built in)
  • the model file
  • the tokenizer file (built in)

Low level API

This package also provides a low-level API that mirrors rwkv.cpp. See rwkv-doc for details.

Thanks

Sponsor

Special thanks to JetBrains for sponsoring this project.

License

Copyright (c) seasonjs. All rights reserved. Licensed under the MIT License. See License.txt in the project root for license information.

Documentation

Index

Constants

const (
	RwkvErrorArgs        RwkvErrors = 1 << 8
	RwkvErrorFile                   = 2 << 8
	RwkvErrorModel                  = 3 << 8
	RwkvErrorModelParams            = 4 << 8
	RwkvErrorGraph                  = 5 << 8
	RwkvErrorCtx                    = 6 << 8
)

Variables

This section is empty.

Functions

func SampleLogits

func SampleLogits(tensor []float32, temperature float32, topP float32, logitBias map[int]float32) (int, error)
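The signature above takes a temperature and a top-p value. As a rough illustration of what temperature plus top-p (nucleus) sampling generally means, here is a self-contained sketch. It is not the package's implementation: sampleTopP is a hypothetical name, the logitBias argument is left out, and the random number u is passed in by the caller so the result is reproducible.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// sampleTopP is a hypothetical stand-in, not the package's SampleLogits.
// u must be a uniform random number in [0, 1); passing it in keeps the
// sketch deterministic. It returns -1 for empty input.
func sampleTopP(logits []float32, temperature, topP float32, u float64) int {
	if len(logits) == 0 {
		return -1
	}
	// Token ids ordered from highest to lowest logit.
	ids := make([]int, len(logits))
	for i := range ids {
		ids[i] = i
	}
	sort.SliceStable(ids, func(a, b int) bool { return logits[ids[a]] > logits[ids[b]] })
	if temperature <= 0 {
		return ids[0] // treat a non-positive temperature as greedy decoding
	}

	// Temperature-scaled softmax, shifted by the max logit for stability.
	maxLogit := float64(logits[ids[0]])
	probs := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp((float64(l) - maxLogit) / float64(temperature))
		sum += probs[i]
	}

	// Keep the smallest prefix whose cumulative probability reaches topP,
	// always keeping at least one token.
	var kept float64
	n := 0
	for n < len(ids) {
		kept += probs[ids[n]] / sum
		n++
		if kept >= float64(topP) {
			break
		}
	}

	// Draw from the kept prefix; scaling u by kept renormalises it.
	target := u * kept
	var acc float64
	for _, id := range ids[:n] {
		acc += probs[id] / sum
		if target < acc {
			return id
		}
	}
	return ids[n-1]
}

func main() {
	logits := []float32{1, 3, 2}
	fmt.Println(sampleTopP(logits, 1.0, 0.5, 0.42)) // prints 1
}
```

With these logits the most likely token already carries about two thirds of the probability mass, so a top-p of 0.5 keeps only that token and the draw is forced.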

Types

type CRwkv

type CRwkv interface {
	// RwkvSetPrintErrors Sets whether errors are automatically printed to stderr.
	// If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
	// - ctx: the context to suppress error messages for.
	//   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
	//   as well as the default for new context.
	// - print_errors: whether error messages should be automatically printed.
	RwkvSetPrintErrors(ctx *RwkvCtx, enable bool)

	// RwkvGetPrintErrors Gets whether errors are automatically printed to stderr.
	// - ctx: the context to retrieve the setting for, or NULL for the global setting.
	RwkvGetPrintErrors(ctx *RwkvCtx) bool

	// RwkvGetLastError Retrieves and clears the error flags.
	// - ctx: the context to retrieve the error for, or NULL for the global error.
	RwkvGetLastError(ctx *RwkvCtx) error

	// RwkvInitFromFile Loads the model from a file and prepares it for inference.
	// Returns NULL on any error.
	// - model_file_path: path to model file in ggml format.
	// - n_threads: count of threads to use, must be positive.
	RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx

	// RwkvCloneContext Creates a new context from an existing one.
	// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
	// Each rwkv_context can have one eval running at a time.
	// Every rwkv_context must be freed using rwkv_free.
	// - ctx: context to be cloned.
	// - n_threads: count of threads to use, must be positive.
	RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx

	// RwkvGpuOffloadLayers Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
	// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
	RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error

	// RwkvEval Evaluates the model for a single token.
	// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
	// Returns false on any error.
	// - token: next token index, in range 0 <= token < n_vocab.
	// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
	// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
	// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
	RwkvEval(ctx *RwkvCtx, token uint32, stateIn []float32, stateOut []float32, logitsOut []float32) error

	// RwkvEvalSequence Evaluates the model for a sequence of tokens.
	// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
	// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
	// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
	// Returns false on any error.
	// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
	// - sequence_len: number of tokens to read from the array.
	// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
	// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
	// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
	RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn []float32, stateOut []float32, logitsOut []float32) error

	// RwkvGetNVocab Returns the number of tokens in the given model's vocabulary.
	// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
	RwkvGetNVocab(ctx *RwkvCtx) uint64

	// RwkvGetNEmbedding Returns the number of elements in the given model's embedding.
	// Useful for reading individual fields of a model's hidden state.
	RwkvGetNEmbedding(ctx *RwkvCtx) uint64

	// RwkvGetNLayer Returns the number of layers in the given model.
	// Useful for always offloading the entire model to GPU.
	RwkvGetNLayer(ctx *RwkvCtx) uint64

	// RwkvGetStateLength Returns the number of float elements in a complete state for the given model.
	// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
	RwkvGetStateLength(ctx *RwkvCtx) uint64

	// RwkvGetLogitsLength Returns the number of float elements in the logits output of a given model.
	// This is currently always identical to n_vocab.
	RwkvGetLogitsLength(ctx *RwkvCtx) uint64

	// RwkvInitState Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
	// Useful in cases where tracking the first call to these functions may be annoying or expensive.
	// State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs.
	// - state: FP32 buffer of size rwkv_get_state_len() to initialize
	RwkvInitState(ctx *RwkvCtx, state []float32)

	// RwkvFree Frees all allocated memory and the context.
	// Does not need to be called on the same thread that created the rwkv_context.
	RwkvFree(ctx *RwkvCtx) error

	// RwkvQuantizeModelFile Quantizes FP32 or FP16 model to one of quantized formats.
	// Returns false on any error. Error messages would be printed to stderr.
	// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
	// - model_file_path_out: quantized model will be written here.
	// - format_name: must be one of available format names below.
	// Available format names:
	// - Q4_0
	// - Q4_1
	// - Q5_0
	// - Q5_1
	// - Q8_0
	RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error

	// RwkvGetSystemInfoString Returns system information string.
	RwkvGetSystemInfoString() string
}

type CRwkvImpl

type CRwkvImpl struct {
	// contains filtered or unexported fields
}

func NewCRwkv

func NewCRwkv(libraryPath string) (*CRwkvImpl, error)

func (*CRwkvImpl) RwkvCloneContext

func (c *CRwkvImpl) RwkvCloneContext(ctx *RwkvCtx, threads uint32) *RwkvCtx

func (*CRwkvImpl) RwkvEval

func (c *CRwkvImpl) RwkvEval(ctx *RwkvCtx, token uint32, stateIn []float32, stateOut []float32, logitsOut []float32) error

func (*CRwkvImpl) RwkvEvalSequence

func (c *CRwkvImpl) RwkvEvalSequence(ctx *RwkvCtx, token uint32, sequenceLen uint64, stateIn []float32, stateOut []float32, logitsOut []float32) error

func (*CRwkvImpl) RwkvFree

func (c *CRwkvImpl) RwkvFree(ctx *RwkvCtx) error

func (*CRwkvImpl) RwkvGetLastError

func (c *CRwkvImpl) RwkvGetLastError(ctx *RwkvCtx) error

func (*CRwkvImpl) RwkvGetLogitsLength

func (c *CRwkvImpl) RwkvGetLogitsLength(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNEmbedding

func (c *CRwkvImpl) RwkvGetNEmbedding(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNLayer

func (c *CRwkvImpl) RwkvGetNLayer(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetNVocab

func (c *CRwkvImpl) RwkvGetNVocab(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetPrintErrors

func (c *CRwkvImpl) RwkvGetPrintErrors(ctx *RwkvCtx) bool

func (*CRwkvImpl) RwkvGetStateLength

func (c *CRwkvImpl) RwkvGetStateLength(ctx *RwkvCtx) uint64

func (*CRwkvImpl) RwkvGetSystemInfoString

func (c *CRwkvImpl) RwkvGetSystemInfoString() string

func (*CRwkvImpl) RwkvGpuOffloadLayers

func (c *CRwkvImpl) RwkvGpuOffloadLayers(ctx *RwkvCtx, nGpuLayers uint32) error

func (*CRwkvImpl) RwkvInitFromFile

func (c *CRwkvImpl) RwkvInitFromFile(filePath string, threads uint32) *RwkvCtx

func (*CRwkvImpl) RwkvInitState

func (c *CRwkvImpl) RwkvInitState(ctx *RwkvCtx, state []float32)

func (*CRwkvImpl) RwkvQuantizeModelFile

func (c *CRwkvImpl) RwkvQuantizeModelFile(ctx *RwkvCtx, in, out string, format QuantizedFormat) error

func (*CRwkvImpl) RwkvSetPrintErrors

func (c *CRwkvImpl) RwkvSetPrintErrors(ctx *RwkvCtx, enable bool)

type GpuType

type GpuType string

type NormalTokenizer

type NormalTokenizer struct {
	// contains filtered or unexported fields
}

func NewNormalTokenizer

func NewNormalTokenizer() (*NormalTokenizer, error)

func (*NormalTokenizer) Decode

func (t *NormalTokenizer) Decode(ids []int) string

func (*NormalTokenizer) Encode

func (t *NormalTokenizer) Encode(input string) ([]int, error)

type QuantizedFormat

type QuantizedFormat string
const (
	Q4_0 QuantizedFormat = "Q4_0"
	Q4_1 QuantizedFormat = "Q4_1"
	Q5_0 QuantizedFormat = "Q5_0"
	Q5_1 QuantizedFormat = "Q5_0"
	Q8_0 QuantizedFormat = "Q8_0"
)

type RwkvCtx

type RwkvCtx struct {
	// contains filtered or unexported fields
}

type RwkvErrors

type RwkvErrors uint32
const (
	RwkvErrorNone RwkvErrors = iota
	RwkvErrorAlloc
	RwkvErrorFileOpen
	RwkvErrorFileStat
	RwkvErrorFileRead
	RwkvErrorFileWrite
	RwkvErrorFileMagic
	RwkvErrorFileVersion
	RwkvErrorDataType
	RwkvErrorUnsupported
	RwkvErrorShape
	RwkvErrorDimension
	RwkvErrorKey
	RwkvErrorData
	RwkvErrorParamMissing
)

Represents an error encountered during a function call. These are flags, so an actual value might contain multiple errors.
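Because the values are flags, a single error can combine one of the shifted category constants (such as RwkvErrorFile) with one of the small detail codes (such as RwkvErrorFileOpen). The sketch below redeclares a few of the constants locally so it compiles without the package. The 0xFF mask and the split helper are assumptions inferred from the `<< 8` layout of the constants, not documented API.

```go
package main

import "fmt"

// Local copies of a few constants from the listings in this document,
// so the sketch compiles without the package or its dynamic library.
type RwkvErrors uint32

const (
	RwkvErrorNone RwkvErrors = iota
	RwkvErrorAlloc
	RwkvErrorFileOpen
)

const (
	RwkvErrorArgs RwkvErrors = 1 << 8
	RwkvErrorFile RwkvErrors = 2 << 8
)

// Assumed masks: the detail code sits in the low byte, the category above it.
const (
	detailMask   RwkvErrors = 0xFF
	categoryMask RwkvErrors = ^detailMask
)

// split is a hypothetical helper that separates a combined value into
// its category part and its detail part.
func split(e RwkvErrors) (category, detail RwkvErrors) {
	return e & categoryMask, e & detailMask
}

func main() {
	combined := RwkvErrorFile | RwkvErrorFileOpen
	category, detail := split(combined)
	fmt.Println(category == RwkvErrorFile, detail == RwkvErrorFileOpen) // prints true true
}
```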

func (RwkvErrors) Error

func (err RwkvErrors) Error() string

type RwkvModel

type RwkvModel struct {
	// contains filtered or unexported fields
}

func NewRwkvAutoModel

func NewRwkvAutoModel(options RwkvOptions) (*RwkvModel, error)

func NewRwkvModel

func NewRwkvModel(dylibPath string, options RwkvOptions) (*RwkvModel, error)

func (*RwkvModel) Close

func (m *RwkvModel) Close() error

func (*RwkvModel) InitState

func (m *RwkvModel) InitState(prompt ...string) (*RwkvState, error)

InitState returns a new state for a new chat context.

func (*RwkvModel) LoadFromFile

func (m *RwkvModel) LoadFromFile(path string) error

func (*RwkvModel) QuantizeModelFile

func (m *RwkvModel) QuantizeModelFile(in, out string, format QuantizedFormat) error

type RwkvOptions

type RwkvOptions struct {
	PrintError       bool
	MaxTokens        int
	StopString       string
	Temperature      float32
	TopP             float32
	TokenizerType    TokenizerType
	CpuThreads       uint32
	GpuEnable        bool
	GpuOffLoadLayers uint32
}

type RwkvState

type RwkvState struct {
	// contains filtered or unexported fields
}

func (*RwkvState) CleanState

func (s *RwkvState) CleanState(prompt ...string) (*RwkvState, error)

CleanState clears the old state and sets a new state for a new chat context.

func (*RwkvState) GetEmbedding

func (s *RwkvState) GetEmbedding(input string, distill bool) ([]float32, error)

GetEmbedding returns the model embedding. In RWKV the embedding is the hidden state, whose length is n_emb*5*n_layer (for example, 46080). If distill is true, the result is reduced to n_emb elements (768 in the same example).
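The arithmetic behind those numbers can be checked directly. In the sketch below, n_emb = 768 comes from the text, while n_layer = 12 is an assumption inferred from 46080 / (768 * 5); stateLen is a hypothetical helper, and how the distilled n_emb slice is chosen is not specified here.

```go
package main

import "fmt"

// stateLen follows the formula quoted above: n_emb * 5 * n_layer.
func stateLen(nEmb, nLayer int) int {
	return nEmb * 5 * nLayer
}

func main() {
	nEmb := 768  // from the text
	nLayer := 12 // assumed: inferred from 46080 / (768 * 5)

	total := stateLen(nEmb, nLayer)
	fmt.Println(total)        // prints 46080
	fmt.Println(total / nEmb) // prints 60: the state holds 60 vectors of n_emb floats
}
```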

func (*RwkvState) LoadState

func (s *RwkvState) LoadState(state []float32) error

func (*RwkvState) Predict

func (s *RwkvState) Predict(input string) (string, error)

Predict returns a response for the current chat.

func (*RwkvState) PredictStream

func (s *RwkvState) PredictStream(input string, output chan string)

func (*RwkvState) SaveState

func (s *RwkvState) SaveState() ([]float32, error)

type Tokenizer

type Tokenizer interface {
	Encode(in string) ([]int, error)
	Decode(in []int) string
}

type TokenizerType

type TokenizerType uint8
const (
	Normal TokenizerType = iota
	World
)

type Trie

type Trie struct {
	Root *TrieNode
}

Trie represents the trie data structure

func (*Trie) Add

func (t *Trie) Add(val string)

Add inserts a key into the trie

func (*Trie) FindLongest

func (t *Trie) FindLongest(key []rune) string

FindLongest finds the longest match in the trie for the given key
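A longest-match lookup of this kind is what lets a tokenizer greedily pick the longest known token at each position. The sketch below shows the general technique with local stand-in types (node, trie, add, findLongest are hypothetical names); the package's TrieNode fields are unexported, so this is not a copy of its implementation.

```go
package main

import "fmt"

// node and trie are local stand-ins for illustration only.
type node struct {
	children map[rune]*node
	terminal bool // true if a stored key ends at this node
}

func newNode() *node { return &node{children: map[rune]*node{}} }

type trie struct{ root *node }

// add inserts val one rune at a time, creating nodes as needed.
func (t *trie) add(val string) {
	n := t.root
	for _, r := range val {
		next, ok := n.children[r]
		if !ok {
			next = newNode()
			n.children[r] = next
		}
		n = next
	}
	n.terminal = true
}

// findLongest returns the longest stored key that is a prefix of key,
// or "" if no stored key matches.
func (t *trie) findLongest(key []rune) string {
	n := t.root
	best := 0
	for i, r := range key {
		next, ok := n.children[r]
		if !ok {
			break
		}
		n = next
		if n.terminal {
			best = i + 1 // remember the end of the longest match so far
		}
	}
	return string(key[:best])
}

func main() {
	t := &trie{root: newNode()}
	for _, tok := range []string{"a", "ab", "abcd"} {
		t.add(tok)
	}
	fmt.Println(t.findLongest([]rune("abcx"))) // prints ab
}
```

The walk continues past "ab" because "abc" is a valid path toward "abcd", but since no key ends there the last recorded match is returned.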

type TrieNode

type TrieNode struct {
	// contains filtered or unexported fields
}

TrieNode represents a node in the trie

func NewTrieNode

func NewTrieNode() *TrieNode

NewTrieNode initializes a new trie node

type WorldTokenizer

type WorldTokenizer struct {
	IndexToToken map[int]string
	TokenToIndex map[string]int
	Trie         *Trie
}

WorldTokenizer represents a tokenizer for encoding and decoding bytes to tokens

func NewWorldTokenizer

func NewWorldTokenizer() (*WorldTokenizer, error)

NewWorldTokenizer initializes a new world tokenizer

func (*WorldTokenizer) Decode

func (wt *WorldTokenizer) Decode(tokens []int) string

Decode decodes tokens to a string

func (*WorldTokenizer) DecodeBytes

func (wt *WorldTokenizer) DecodeBytes(tokens []int) []rune

DecodeBytes decodes tokens to bytes

func (*WorldTokenizer) Encode

func (wt *WorldTokenizer) Encode(src string) ([]int, error)

Encode encodes a string to tokens

func (*WorldTokenizer) EncodeBytes

func (wt *WorldTokenizer) EncodeBytes(src []rune) ([]int, error)

EncodeBytes encodes bytes to tokens

Directories

Path Synopsis
deps
