replicateapi

package module
v0.0.1 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Mar 4, 2023 License: Apache-2.0 Imports: 11 Imported by: 0

README

ReplicateAPI

A minimal Go wrapper around the Replicate HTTP API.

Quick start

go get github.com/StyleSpaceAI/replicateapi@latest
import "github.com/StyleSpaceAI/replicateapi"

const (
    token = "getYourTokenFromReplicateProfile"
    MODELNAME = "stability-ai/stable-diffusion"
)

func main() {
	// Initialize a new API client
	cli, err := replicateapi.NewClient(token, MODELNAME, "")
	if err != nil {
		log.Fatal("init client", err)
	}

	// Fetch all the available versions for this model
	vers, err := cli.GetModelVersions(context.Background())
	if err != nil {
		log.Fatal("fetch versions", err)
	}

	// Picking the latest version of the model
	cli.Version = vers[0].ID

	// Register an asynchronous prediction task
	result, err := cli.CreatePrediction(context.Background(), map[string]interface{}{
		"prompt": "putin sucks huge cock, 4k",
	})
	if err != nil {
		log.Fatal("create prediction", err)
	}

	// The response of the API is async, so we need to wait for the response
	for keepChecking := true; keepChecking; {
		time.Sleep(time.Second * 3)

		// Fetch status and results of existnig prediction
		result, err = cli.GetResult(context.Background(), result.ID)
		if err != nil {
			log.Fatal("fetch prediction result", err)
		}

		switch result.Status {
		case replicateapi.PredictionStatusSucceeded, replicateapi.PredictionStatusCanceled, replicateapi.PredictionStatusFailed:
			// Final statuses
			keepChecking = false
		case replicateapi.PredictionStatusProcessing, replicateapi.PredictionStatusStarting:
			// Still processing
		}
	}
	fmt.Printf("%+v\n", result)
}

Documentation

Index

Constants

View Source
// Statuses reported while a prediction is still in flight.
const (
	// PredictionStatusStarting means the prediction is starting up. If this
	// status lasts longer than a few seconds, it is typically because a new
	// worker is being started to run the prediction.
	PredictionStatusStarting = "starting"
	// PredictionStatusProcessing means the model's predict() method is
	// currently running.
	PredictionStatusProcessing = "processing"
)

// Final statuses: a prediction reporting one of these is done.
const (
	// PredictionStatusSucceeded means the prediction completed successfully.
	PredictionStatusSucceeded = "succeeded"
	// PredictionStatusFailed means the prediction hit an error during processing.
	PredictionStatusFailed = "failed"
	// PredictionStatusCanceled means the user canceled the prediction.
	PredictionStatusCanceled = "canceled"
)

Variables

View Source
// Endpoint settings for the replicate API. Being variables rather than
// constants, they can be reassigned by callers.
var (
	// URI is the scheme and host of the replicate API.
	URI = "https://api.replicate.com"

	// Version is the replicate API version path segment.
	Version = "v1"
)
View Source
// Sentinel errors; compare against them with errors.Is.
var (
	// ErrUnauthorized signals an authorization problem: check your
	// authorization token and that the model is available to you.
	ErrUnauthorized = errors.New("unauthorized")

	// ErrRateLimitReached signals that the API rate limit was hit. See the
	// official limits at https://replicate.com/docs/reference/http#rate-limits
	ErrRateLimitReached = errors.New("rate limit reached")
)

Functions

func EncodeImage

func EncodeImage(image []byte) (string, error)

EncodeImage encodes an image into the format accepted by the replicate API.

Types

type Client

// Client talks to the replicate.com API for a single model. NewClient is the
// intended constructor.
type Client struct {
	// AuthorizationToken is the API token from your Replicate profile.
	AuthorizationToken string
	// Owner and Model identify the model. NOTE(review): presumably NewClient
	// splits an "owner/model" name such as "stability-ai/stable-diffusion"
	// into these two fields — confirm against its implementation.
	Owner              string
	Model              string
	// Version is the model version ID used for predictions; it may be set
	// after construction, e.g. to a ModelVersion.ID from GetModelVersions.
	Version            string

	// HTTPClient performs the API requests. NOTE(review): whether a nil value
	// falls back to a default client is not visible here — confirm.
	HTTPClient *http.Client
}

Client is a client for the replicate.com API. Use NewClient to initialize it.

func NewClient

func NewClient(token, model, version string) (*Client, error)

NewClient creates a new API client

func (*Client) CreatePrediction

func (c *Client) CreatePrediction(ctx context.Context, input map[string]interface{}) (*PredictionResult, error)

CreatePrediction registers an asynchronous prediction request with the replicate API.

func (*Client) GetModelVersions

func (c *Client) GetModelVersions(ctx context.Context) ([]*ModelVersion, error)

GetModelVersions returns the list of versions available for the model set in the client. All the versions are sorted by creation time.

func (*Client) GetResult

func (c *Client) GetResult(ctx context.Context, predictionID string) (*PredictionResult, error)

GetResult fetches the result of a prediction by its ID.

type ModelVersion

// ModelVersion is one version of a model together with its schema, as decoded
// from the replicate API's JSON.
type ModelVersion struct {
	// ID identifies the version; assign it to Client.Version to use it.
	ID            string                 `json:"id"`
	// CreatedAt is when the version was created.
	CreatedAt     time.Time              `json:"created_at"`
	// CogVersion is reported by the API as "cog_version" (presumably the
	// version of the Cog tool that packaged the model — confirm).
	CogVersion    string                 `json:"cog_version"`
	// OpenapiSchema is the version's OpenAPI schema, decoded as a generic map.
	OpenapiSchema map[string]interface{} `json:"openapi_schema"`
}

ModelVersion represents a single version of the model with the related schema

type PredictionResult

// PredictionResult is a single prediction as returned by the replicate API;
// both CreatePrediction and GetResult return it.
type PredictionResult struct {
	// ID identifies the prediction; pass it to GetResult to poll for updates.
	ID      string `json:"id"`
	// Version is presumably the ID of the model version the prediction runs on.
	Version string `json:"version"`
	// Urls holds API links for this prediction (presumably for fetching and
	// canceling it — confirm against the replicate API docs).
	Urls    struct {
		Get    string `json:"get"`
		Cancel string `json:"cancel"`
	} `json:"urls"`
	// CreatedAt is when the prediction was created.
	CreatedAt   time.Time              `json:"created_at"`
	// StartedAt and CompletedAt are left untyped (any). NOTE(review): likely a
	// timestamp string or JSON null — confirm before type-asserting.
	StartedAt   any                    `json:"started_at"`
	CompletedAt any                    `json:"completed_at"`
	// Status is one of the PredictionStatus* values; succeeded, failed and
	// canceled are final.
	Status      PredictionStatus       `json:"status"`
	// Input presumably echoes the inputs the prediction was created with.
	Input       map[string]interface{} `json:"input"`
	// Output is model-specific and left as any.
	Output      any                    `json:"output"`
	// Error and Logs are left untyped, exactly as decoded from JSON.
	Error       any                    `json:"error"`
	Logs        any                    `json:"logs"`
	// Metrics holds whatever metrics the API reports, as a generic map.
	Metrics     map[string]interface{} `json:"metrics"`
}

PredictionResult is a representation of a single prediction from the replicate API.

func (*PredictionResult) Refresh

func (p *PredictionResult) Refresh(ctx context.Context, c *Client) error

Refresh updates the status of the prediction in place.

type PredictionStatus

// PredictionStatus is the status string the replicate API reports for a
// prediction. It is a type alias (=) of string, so plain strings and the
// untyped PredictionStatus* constants are interchangeable with it.
type PredictionStatus = string

PredictionStatus is returned from replicate API

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL