replicate

package module
v0.0.2 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Dec 5, 2023 License: Apache-2.0 Imports: 17 Imported by: 0

README

Replicate Go client

Go Reference

A Go client for Replicate. It lets you run models from your Go code and do everything else you can do with Replicate's HTTP API.

Requirements

  • Go 1.20+

Installation

Use go get to install the Replicate package:

go get -u github.com/1019tech/lib-backend

Include the Replicate package in your project:

import "github.com/1019tech/lib-backend"

Usage

import (
	"context"

	"github.com/1019tech/lib-backend"
)

// You can also provide a token directly with `replicate.NewClient(replicate.WithToken("r8_..."))`
client, err := replicate.NewClient(replicate.WithTokenFromEnv())
if err != nil {
	// handle error
}

// https://replicate.com/stability-ai/stable-diffusion
version := "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"

input := replicate.PredictionInput{
	"prompt": "an astronaut riding a horse on mars, hd, dramatic lighting",
}

webhook := replicate.Webhook{
	URL:    "https://example.com/webhook",
	Events: []replicate.WebhookEventType{replicate.WebhookEventStart, replicate.WebhookEventCompleted},
}

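// The final argument controls streaming; pass true to request a streamable prediction.
prediction, err := client.CreatePrediction(context.Background(), version, input, &webhook, false)
if err != nil {
	// handle error
}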

License

Replicate's Go client is released under the Apache 2.0 license. See LICENSE.txt.

Documentation

Index

Constants

View Source
const (
	// SSETypeDone is the type of SSEEvent that indicates the prediction is done. The Data field will contain an empty JSON object.
	SSETypeDone = "done"

	// SSETypeError is the type of SSEEvent that indicates an error occurred during the prediction. The Data field will contain JSON with the error.
	SSETypeError = "error"

	// SSETypeLogs is the type of SSEEvent that contains logs from the prediction.
	SSETypeLogs = "logs"

	// SSETypeOutput is the type of SSEEvent that contains output from the prediction.
	SSETypeOutput = "output"
)

Variables

View Source
var (
	ErrInvalidIdentifier = errors.New("invalid identifier, it must be in the format \"owner/name\" or \"owner/name:version\"")
)
View Source
var (
	ErrInvalidUTF8Data = errors.New("invalid UTF-8 data")
)
View Source
var (
	ErrNoAuth = errors.New(`no auth token or token source provided -- perhaps you forgot to pass replicate.WithToken("...")`)
)

Functions

func Paginate

func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) (<-chan []T, <-chan error)

Paginate takes a Client and an initial Page, and iterates through the pages of results. Each page's results are delivered on the first returned channel, and errors on the second.

Types

type APIError

type APIError struct {
	// Type is a URI that identifies the error type.
	Type string `json:"type,omitempty"`

	// Title is a short human-readable summary of the error.
	Title string `json:"title,omitempty"`

	// Status is the HTTP status code.
	Status int `json:"status,omitempty"`

	// Detail is a human-readable explanation of the error.
	Detail string `json:"detail,omitempty"`

	// Instance is a URI that identifies the specific occurrence of the error.
	Instance string `json:"instance,omitempty"`
}

APIError represents an error returned by the Replicate API.

func (APIError) Error

func (e APIError) Error() string

func (*APIError) WriteHTTPResponse

func (e *APIError) WriteHTTPResponse(w http.ResponseWriter)

type Backoff

type Backoff interface {
	NextDelay(retries int) time.Duration
}

Backoff is an interface for backoff strategies.

type Client

type Client struct {
	// contains filtered or unexported fields
}

Client is a client for the Replicate API.

func NewClient

func NewClient(opts ...ClientOption) (*Client, error)

NewClient creates a new Replicate API client.

func (*Client) CancelPrediction

func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error)

CancelPrediction cancels a prediction.

func (*Client) CancelTraining

func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Training, error)

CancelTraining sends a request to the Replicate API to cancel a training.

func (*Client) CreateModel

func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName string, options CreateModelOptions) (*Model, error)

CreateModel creates a new model.

func (*Client) CreatePrediction

func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePrediction sends a request to the Replicate API to create a prediction.

func (*Client) CreatePredictionWithDeployment

func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deployment_owner string, deployment_name string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePredictionWithDeployment sends a request to the Replicate API to create a prediction using the specified deployment.

func (*Client) CreatePredictionWithModel

func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner string, modelName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePredictionWithModel sends a request to the Replicate API to create a prediction for a model.

func (*Client) CreateTraining

func (r *Client) CreateTraining(ctx context.Context, model_owner string, model_name string, version string, destination string, input TrainingInput, webhook *Webhook) (*Training, error)

CreateTraining sends a request to the Replicate API to create a new training.

func (*Client) GetCollection

func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error)

GetCollection returns a collection by slug.

func (*Client) GetModel

func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error)

GetModel retrieves information about a model.

func (*Client) GetModelVersion

func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error)

GetModelVersion retrieves a specific version of a model.

func (*Client) GetPrediction

func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error)

GetPrediction retrieves a prediction from the Replicate API by its ID.

func (*Client) GetTraining

func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error)

GetTraining sends a request to the Replicate API to get a training.

func (*Client) ListCollections

func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error)

ListCollections returns a list of all collections.

func (*Client) ListHardware

func (r *Client) ListHardware(ctx context.Context) (*[]Hardware, error)

ListHardware returns a list of available hardware.

func (*Client) ListModelVersions

func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error)

ListModelVersions lists the versions of a model.

func (*Client) ListModels

func (r *Client) ListModels(ctx context.Context) (*Page[Model], error)

ListModels lists public models.

func (*Client) ListPredictions

func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error)

ListPredictions returns a paginated list of predictions.

func (*Client) ListTrainings

func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error)

ListTrainings returns a list of trainings.

func (*Client) Run

func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error)

func (*Client) Stream

func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error)

func (*Client) StreamPrediction

func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (<-chan SSEEvent, <-chan error)

func (*Client) Wait

func (r *Client) Wait(ctx context.Context, prediction *Prediction, opts ...WaitOption) error

Wait waits for a prediction to finish.

This function blocks until the prediction has finished or the context is canceled. If the prediction has already finished, the function returns immediately. If the polling interval is less than or equal to zero, an error is returned.

func (*Client) WaitAsync

func (r *Client) WaitAsync(ctx context.Context, prediction *Prediction, opts ...WaitOption) (<-chan *Prediction, <-chan error)

WaitAsync returns a channel that receives the prediction as it progresses.

The channel is closed when the prediction has finished or the context is canceled. If the prediction has already finished, the channel is closed immediately. If the polling interval is less than or equal to zero, an error is sent to the error channel.

type ClientOption

type ClientOption func(*clientOptions) error

ClientOption is a function that modifies an options struct.

func WithBaseURL

func WithBaseURL(baseURL string) ClientOption

WithBaseURL sets the base URL for the client.

func WithHTTPClient

func WithHTTPClient(httpClient *http.Client) ClientOption

WithHTTPClient sets the HTTP client used by the client.

func WithRetryPolicy

func WithRetryPolicy(maxRetries int, backoff Backoff) ClientOption

WithRetryPolicy sets the retry policy used by the client.

func WithToken

func WithToken(token string) ClientOption

WithToken sets the auth token used by the client.

func WithTokenFromEnv

func WithTokenFromEnv() ClientOption

WithTokenFromEnv configures the client to use the auth token provided in the REPLICATE_API_TOKEN environment variable.

func WithUserAgent

func WithUserAgent(userAgent string) ClientOption

WithUserAgent sets the User-Agent header on requests made by the client.

type Collection

type Collection struct {
	Name        string   `json:"name"`
	Slug        string   `json:"slug"`
	Description string   `json:"description"`
	Models      *[]Model `json:"models,omitempty"`
	// contains filtered or unexported fields
}

func (Collection) MarshalJSON

func (c Collection) MarshalJSON() ([]byte, error)

func (*Collection) UnmarshalJSON

func (c *Collection) UnmarshalJSON(data []byte) error

type ConstantBackoff

type ConstantBackoff struct {
	Base   time.Duration
	Jitter time.Duration
}

ConstantBackoff is a backoff strategy that returns a constant delay with some jitter.

func (*ConstantBackoff) NextDelay

func (b *ConstantBackoff) NextDelay(retries int) time.Duration

NextDelay returns the next delay.

type CreateModelOptions

type CreateModelOptions struct {
	Visibility    string  `json:"visibility"`
	Hardware      string  `json:"hardware"`
	Description   *string `json:"description,omitempty"`
	GithubURL     *string `json:"github_url,omitempty"`
	PaperURL      *string `json:"paper_url,omitempty"`
	LicenseURL    *string `json:"license_url,omitempty"`
	CoverImageURL *string `json:"cover_image_url,omitempty"`
}

type ExponentialBackoff

type ExponentialBackoff struct {
	Base       time.Duration
	Multiplier float64
	Jitter     time.Duration
}

ExponentialBackoff is a backoff strategy that returns an exponentially increasing delay with some jitter.

func (*ExponentialBackoff) NextDelay

func (b *ExponentialBackoff) NextDelay(retries int) time.Duration

NextDelay returns the next delay.

type Hardware

type Hardware struct {
	SKU  string `json:"sku"`
	Name string `json:"name"`
	// contains filtered or unexported fields
}

func (Hardware) MarshalJSON

func (h Hardware) MarshalJSON() ([]byte, error)

func (*Hardware) UnmarshalJSON

func (h *Hardware) UnmarshalJSON(data []byte) error

type Identifier

type Identifier struct {
	// Owner is the username of the model owner.
	Owner string

	// Name is the name of the model.
	Name string

	// Version is the version of the model.
	Version *string
}

Identifier represents a reference to a Replicate model with an optional version.

func ParseIdentifier

func ParseIdentifier(identifier string) (*Identifier, error)

func (*Identifier) String

func (i *Identifier) String() string

type Model

type Model struct {
	URL            string        `json:"url"`
	Owner          string        `json:"owner"`
	Name           string        `json:"name"`
	Description    string        `json:"description"`
	Visibility     string        `json:"visibility"`
	GithubURL      string        `json:"github_url"`
	PaperURL       string        `json:"paper_url"`
	LicenseURL     string        `json:"license_url"`
	RunCount       int           `json:"run_count"`
	CoverImageURL  string        `json:"cover_image_url"`
	DefaultExample *Prediction   `json:"default_example"`
	LatestVersion  *ModelVersion `json:"latest_version"`
	// contains filtered or unexported fields
}

func (Model) MarshalJSON

func (m Model) MarshalJSON() ([]byte, error)

func (*Model) UnmarshalJSON

func (m *Model) UnmarshalJSON(data []byte) error

type ModelVersion

type ModelVersion struct {
	ID            string      `json:"id"`
	CreatedAt     string      `json:"created_at"`
	CogVersion    string      `json:"cog_version"`
	OpenAPISchema interface{} `json:"openapi_schema"`
	// contains filtered or unexported fields
}

func (ModelVersion) MarshalJSON

func (m ModelVersion) MarshalJSON() ([]byte, error)

func (*ModelVersion) UnmarshalJSON

func (m *ModelVersion) UnmarshalJSON(data []byte) error

type Page

type Page[T any] struct {
	Previous *string `json:"previous,omitempty"`
	Next     *string `json:"next,omitempty"`
	Results  []T     `json:"results"`
	// contains filtered or unexported fields
}

Page represents a paginated response from Replicate's API.

func (Page[T]) MarshalJSON

func (p Page[T]) MarshalJSON() ([]byte, error)

func (*Page[T]) UnmarshalJSON

func (p *Page[T]) UnmarshalJSON(data []byte) error

type Prediction

type Prediction struct {
	ID      string           `json:"id"`
	Status  Status           `json:"status"`
	Model   string           `json:"model"`
	Version string           `json:"version"`
	Input   PredictionInput  `json:"input"`
	Output  PredictionOutput `json:"output,omitempty"`
	Source  Source           `json:"source"`
	Error   interface{}      `json:"error,omitempty"`
	Logs    *string          `json:"logs,omitempty"`
	Metrics *struct {
		PredictTime *float64 `json:"predict_time,omitempty"`
	} `json:"metrics,omitempty"`
	Webhook             *string            `json:"webhook,omitempty"`
	WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"`
	URLs                map[string]string  `json:"urls,omitempty"`
	CreatedAt           string             `json:"created_at"`
	StartedAt           *string            `json:"started_at,omitempty"`
	CompletedAt         *string            `json:"completed_at,omitempty"`
	// contains filtered or unexported fields
}

func (Prediction) MarshalJSON

func (p Prediction) MarshalJSON() ([]byte, error)

func (Prediction) Progress

func (p Prediction) Progress() *PredictionProgress

func (*Prediction) UnmarshalJSON

func (p *Prediction) UnmarshalJSON(data []byte) error

type PredictionInput

type PredictionInput map[string]interface{}

type PredictionOutput

type PredictionOutput interface{}

type PredictionProgress

type PredictionProgress struct {
	Percentage float64
	Current    int
	Total      int
}

type SSEEvent

type SSEEvent struct {
	Type string
	ID   string
	Data string
}

SSEEvent represents a Server-Sent Event.

func (*SSEEvent) String

func (e *SSEEvent) String() string

type Source

type Source string
const (
	SourceWeb Source = "web"
	SourceAPI Source = "api"
)

type Status

type Status string
const (
	Starting   Status = "starting"
	Processing Status = "processing"
	Succeeded  Status = "succeeded"
	Failed     Status = "failed"
	Canceled   Status = "canceled"
)

func (Status) String

func (s Status) String() string

func (Status) Terminated

func (s Status) Terminated() bool

type Training

type Training Prediction

type TrainingInput

type TrainingInput PredictionInput

type WaitOption

type WaitOption func(*waitOptions) error

WaitOption is a function that modifies an options struct.

func WithPollingInterval

func WithPollingInterval(interval time.Duration) WaitOption

WithPollingInterval sets the interval between attempts.

type Webhook

type Webhook struct {
	URL    string
	Events []WebhookEventType
}

type WebhookEventType

type WebhookEventType string
const (
	WebhookEventStart     WebhookEventType = "start"
	WebhookEventOutput    WebhookEventType = "output"
	WebhookEventLogs      WebhookEventType = "logs"
	WebhookEventCompleted WebhookEventType = "completed"
)

func (WebhookEventType) String

func (w WebhookEventType) String() string

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL