replicate

package module
v0.13.2 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Dec 4, 2023 License: Apache-2.0 Imports: 17 Imported by: 33

README

Replicate Go client

Go Reference

A Go client for Replicate. It lets you run models from your Go code and do everything else you can do with Replicate's HTTP API.

Requirements

  • Go 1.20+

Installation

Use go get to install the Replicate package:

go get -u github.com/replicate/replicate-go

Include the Replicate package in your project:

import "github.com/replicate/replicate-go"

Usage

import (
	"context"

	"github.com/replicate/replicate-go"
)

// You can also provide a token directly with `replicate.NewClient(replicate.WithToken("r8_..."))`
client, err := replicate.NewClient(replicate.WithTokenFromEnv())
if err != nil {
	// handle error
}

// https://replicate.com/stability-ai/stable-diffusion
version := "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"

input := replicate.PredictionInput{
  	"prompt": "an astronaut riding a horse on mars, hd, dramatic lighting",
}

webhook := replicate.Webhook{
  	URL:    "https://example.com/webhook",
  	Events: []replicate.WebhookEventType{"start", "completed"},
}

// The last argument is CreatePrediction's `stream` flag.
prediction, err := client.CreatePrediction(context.Background(), version, input, &webhook, false)

License

Replicate's Go client is released under the Apache 2.0 license. See LICENSE.txt.

Documentation

Index

Constants

This section is empty.

Variables

View Source
var (
	ErrInvalidIdentifier = errors.New("invalid identifier, it must be in the format \"owner/name\" or \"owner/name:version\"")
)
View Source
var (
	ErrInvalidUTF8Data = errors.New("invalid UTF-8 data")
)
View Source
var (
	ErrNoAuth = errors.New(`no auth token or token source provided -- perhaps you forgot to pass replicate.WithToken("...")`)
)

Functions

func Paginate

func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) (<-chan []T, <-chan error)

Paginate takes a Client and an initial Page, and iterates through the pages of results. It delivers each page's results on the first returned channel and reports errors on the second.

Types

type APIError

type APIError struct {
	Detail string `json:"detail"`
}

APIError represents an error returned by the Replicate API.

func (APIError) Error

func (e APIError) Error() string

Error implements the error interface.

type Backoff added in v0.7.0

type Backoff interface {
	NextDelay(retries int) time.Duration
}

Backoff is an interface for backoff strategies.

type Client

type Client struct {
	// contains filtered or unexported fields
}

Client is a client for the Replicate API.

func NewClient

func NewClient(opts ...ClientOption) (*Client, error)

NewClient creates a new Replicate API client.

func (*Client) CancelTraining

func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Training, error)

CancelTraining sends a request to the Replicate API to cancel a training.

func (*Client) CreateModel added in v0.11.0

func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName string, options CreateModelOptions) (*Model, error)

CreateModel creates a new model.

func (*Client) CreatePrediction

func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePrediction sends a request to the Replicate API to create a prediction.

func (*Client) CreatePredictionWithDeployment added in v0.9.0

func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deployment_owner string, deployment_name string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePredictionWithDeployment sends a request to the Replicate API to create a prediction using the specified deployment.

func (*Client) CreatePredictionWithModel added in v0.13.0

func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner string, modelName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePredictionWithModel sends a request to the Replicate API to create a prediction for a model.

func (*Client) CreateTraining

func (r *Client) CreateTraining(ctx context.Context, model_owner string, model_name string, version string, destination string, input TrainingInput, webhook *Webhook) (*Training, error)

CreateTraining sends a request to the Replicate API to create a new training.

func (*Client) GetCollection

func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error)

GetCollection returns a collection by slug.

func (*Client) GetModel

func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error)

GetModel retrieves information about a model.

func (*Client) GetModelVersion

func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error)

GetModelVersion retrieves a specific version of a model.

func (*Client) GetPrediction

func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error)

GetPrediction retrieves a prediction from the Replicate API by its ID.

func (*Client) GetTraining

func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error)

GetTraining sends a request to the Replicate API to get a training.

func (*Client) ListCollections

func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error)

ListCollections returns a list of all collections.

func (*Client) ListHardware added in v0.11.0

func (r *Client) ListHardware(ctx context.Context) (*[]Hardware, error)

ListHardware returns a list of available hardware.

func (*Client) ListModelVersions

func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error)

ListModelVersions lists the versions of a model.

func (*Client) ListModels added in v0.10.0

func (r *Client) ListModels(ctx context.Context) (*Page[Model], error)

ListModels lists public models.

func (*Client) ListPredictions

func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error)

ListPredictions returns a paginated list of predictions.

func (*Client) ListTrainings

func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error)

ListTrainings returns a list of trainings.

func (*Client) Run

func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error)

func (*Client) Stream added in v0.13.0

func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error)

func (*Client) StreamPrediction added in v0.13.1

func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (<-chan SSEEvent, <-chan error)

func (*Client) Wait

func (r *Client) Wait(ctx context.Context, prediction *Prediction, opts ...WaitOption) error

Wait for a prediction to finish.

This function blocks until the prediction has finished, or the context is canceled. If the prediction has already finished, the function returns immediately. If the polling interval is less than or equal to zero, an error is returned.

func (*Client) WaitAsync added in v0.7.0

func (r *Client) WaitAsync(ctx context.Context, prediction *Prediction, opts ...WaitOption) (<-chan *Prediction, <-chan error)

WaitAsync returns a channel that receives the prediction as it progresses.

The channel is closed when the prediction has finished, or the context is canceled. If the prediction has already finished, the channel is closed immediately. If the polling interval is less than or equal to zero, an error is sent to the error channel.

type ClientOption

type ClientOption func(*clientOptions) error

ClientOption is a function that modifies an options struct.

func WithBaseURL

func WithBaseURL(baseURL string) ClientOption

WithBaseURL sets the base URL for the client.

func WithHTTPClient

func WithHTTPClient(httpClient *http.Client) ClientOption

WithHTTPClient sets the HTTP client used by the client.

func WithRetryPolicy added in v0.7.0

func WithRetryPolicy(maxRetries int, backoff Backoff) ClientOption

WithRetryPolicy sets the retry policy used by the client.

func WithToken added in v0.6.0

func WithToken(token string) ClientOption

WithToken sets the auth token used by the client.

func WithTokenFromEnv added in v0.6.0

func WithTokenFromEnv() ClientOption

WithTokenFromEnv configures the client to use the auth token provided in the REPLICATE_API_TOKEN environment variable.

func WithUserAgent

func WithUserAgent(userAgent string) ClientOption

WithUserAgent sets the User-Agent header on requests made by the client.

type Collection

type Collection struct {
	Name        string   `json:"name"`
	Slug        string   `json:"slug"`
	Description string   `json:"description"`
	Models      *[]Model `json:"models,omitempty"`
	// contains filtered or unexported fields
}

func (Collection) MarshalJSON added in v0.8.1

func (c Collection) MarshalJSON() ([]byte, error)

func (*Collection) UnmarshalJSON added in v0.8.1

func (c *Collection) UnmarshalJSON(data []byte) error

type ConstantBackoff added in v0.7.0

type ConstantBackoff struct {
	Base   time.Duration
	Jitter time.Duration
}

ConstantBackoff is a backoff strategy that returns a constant delay with some jitter.

func (*ConstantBackoff) NextDelay added in v0.7.0

func (b *ConstantBackoff) NextDelay(retries int) time.Duration

NextDelay returns the next delay.

type CreateModelOptions added in v0.11.0

type CreateModelOptions struct {
	Visibility    string  `json:"visibility"`
	Hardware      string  `json:"hardware"`
	Description   *string `json:"description,omitempty"`
	GithubURL     *string `json:"github_url,omitempty"`
	PaperURL      *string `json:"paper_url,omitempty"`
	LicenseURL    *string `json:"license_url,omitempty"`
	CoverImageURL *string `json:"cover_image_url,omitempty"`
}

type ExponentialBackoff added in v0.7.0

type ExponentialBackoff struct {
	Base       time.Duration
	Multiplier float64
	Jitter     time.Duration
}

ExponentialBackoff is a backoff strategy that returns an exponentially increasing delay with some jitter.

func (*ExponentialBackoff) NextDelay added in v0.7.0

func (b *ExponentialBackoff) NextDelay(retries int) time.Duration

NextDelay returns the next delay.

type Hardware added in v0.11.0

type Hardware struct {
	SKU  string `json:"sku"`
	Name string `json:"name"`
	// contains filtered or unexported fields
}

func (Hardware) MarshalJSON added in v0.11.0

func (h Hardware) MarshalJSON() ([]byte, error)

func (*Hardware) UnmarshalJSON added in v0.11.0

func (h *Hardware) UnmarshalJSON(data []byte) error

type Identifier added in v0.13.0

type Identifier struct {
	// Owner is the username of the model owner.
	Owner string

	// Name is the name of the model.
	Name string

	// Version is the version of the model.
	Version *string
}

Identifier represents a reference to a Replicate model with an optional version.

func ParseIdentifier added in v0.13.0

func ParseIdentifier(identifier string) (*Identifier, error)

func (*Identifier) String added in v0.13.1

func (i *Identifier) String() string

type Model

type Model struct {
	URL            string        `json:"url"`
	Owner          string        `json:"owner"`
	Name           string        `json:"name"`
	Description    string        `json:"description"`
	Visibility     string        `json:"visibility"`
	GithubURL      string        `json:"github_url"`
	PaperURL       string        `json:"paper_url"`
	LicenseURL     string        `json:"license_url"`
	RunCount       int           `json:"run_count"`
	CoverImageURL  string        `json:"cover_image_url"`
	DefaultExample *Prediction   `json:"default_example"`
	LatestVersion  *ModelVersion `json:"latest_version"`
	// contains filtered or unexported fields
}

func (Model) MarshalJSON added in v0.8.1

func (m Model) MarshalJSON() ([]byte, error)

func (*Model) UnmarshalJSON added in v0.8.1

func (m *Model) UnmarshalJSON(data []byte) error

type ModelVersion

type ModelVersion struct {
	ID            string      `json:"id"`
	CreatedAt     string      `json:"created_at"`
	CogVersion    string      `json:"cog_version"`
	OpenAPISchema interface{} `json:"openapi_schema"`
	// contains filtered or unexported fields
}

func (ModelVersion) MarshalJSON added in v0.8.1

func (m ModelVersion) MarshalJSON() ([]byte, error)

func (*ModelVersion) UnmarshalJSON added in v0.8.1

func (m *ModelVersion) UnmarshalJSON(data []byte) error

type Page

type Page[T any] struct {
	Previous *string `json:"previous,omitempty"`
	Next     *string `json:"next,omitempty"`
	Results  []T     `json:"results"`
	// contains filtered or unexported fields
}

Page represents a paginated response from Replicate's API.

func (Page[T]) MarshalJSON added in v0.8.1

func (p Page[T]) MarshalJSON() ([]byte, error)

func (*Page[T]) UnmarshalJSON added in v0.8.1

func (p *Page[T]) UnmarshalJSON(data []byte) error

type Prediction

type Prediction struct {
	ID      string           `json:"id"`
	Status  Status           `json:"status"`
	Model   string           `json:"model"`
	Version string           `json:"version"`
	Input   PredictionInput  `json:"input"`
	Output  PredictionOutput `json:"output,omitempty"`
	Source  Source           `json:"source"`
	Error   interface{}      `json:"error,omitempty"`
	Logs    *string          `json:"logs,omitempty"`
	Metrics *struct {
		PredictTime *float64 `json:"predict_time,omitempty"`
	} `json:"metrics,omitempty"`
	Webhook             *string            `json:"webhook,omitempty"`
	WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"`
	URLs                map[string]string  `json:"urls,omitempty"`
	CreatedAt           string             `json:"created_at"`
	StartedAt           *string            `json:"started_at,omitempty"`
	CompletedAt         *string            `json:"completed_at,omitempty"`
	// contains filtered or unexported fields
}

func (Prediction) MarshalJSON added in v0.8.1

func (p Prediction) MarshalJSON() ([]byte, error)

func (Prediction) Progress added in v0.8.0

func (p Prediction) Progress() *PredictionProgress

func (*Prediction) UnmarshalJSON added in v0.8.1

func (p *Prediction) UnmarshalJSON(data []byte) error

type PredictionInput

type PredictionInput map[string]interface{}

type PredictionOutput

type PredictionOutput interface{}

type PredictionProgress added in v0.8.0

type PredictionProgress struct {
	Percentage float64
	Current    int
	Total      int
}

type SSEEvent added in v0.13.0

type SSEEvent struct {
	Type string
	ID   string
	Data string
}

SSEEvent represents a Server-Sent Event.

type Source

type Source string
const (
	SourceWeb Source = "web"
	SourceAPI Source = "api"
)

type Status

type Status string
const (
	Starting   Status = "starting"
	Processing Status = "processing"
	Succeeded  Status = "succeeded"
	Failed     Status = "failed"
	Canceled   Status = "canceled"
)

func (Status) String

func (s Status) String() string

func (Status) Terminated

func (s Status) Terminated() bool

type Training

type Training Prediction

type TrainingInput

type TrainingInput PredictionInput

type WaitOption added in v0.7.0

type WaitOption func(*waitOptions) error

WaitOption is a function that modifies an options struct.

func WithPollingInterval added in v0.7.0

func WithPollingInterval(interval time.Duration) WaitOption

WithPollingInterval sets the interval between attempts.

type Webhook

type Webhook struct {
	URL    string
	Events []WebhookEventType
}

type WebhookEventType

type WebhookEventType string
const (
	WebhookEventStart     WebhookEventType = "start"
	WebhookEventOutput    WebhookEventType = "output"
	WebhookEventLogs      WebhookEventType = "logs"
	WebhookEventCompleted WebhookEventType = "completed"
)

func (WebhookEventType) String

func (w WebhookEventType) String() string

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL