replicate

package module
v0.5.0
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jul 31, 2023 License: Apache-2.0 Imports: 10 Imported by: 33

README

Replicate Go client

A Go client for Replicate. It lets you run models from your Go code, and do everything else that Replicate's HTTP API supports.

Requirements

  • Go 1.20+

Installation

Use go get to install the Replicate package:

go get -u github.com/replicate/replicate-go

Include the Replicate package in your project:

import "github.com/replicate/replicate-go"

Usage

import (
	"context"
	"os"

	"github.com/replicate/replicate-go"
)

client := replicate.NewClient(os.Getenv("REPLICATE_API_TOKEN"))

// https://replicate.com/stability-ai/stable-diffusion
version := "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"

input := replicate.PredictionInput{
  	"prompt": "an astronaut riding a horse on mars, hd, dramatic lighting",
}

webhook := replicate.Webhook{
  	URL:    "https://example.com/webhook",
  	Events: []replicate.WebhookEventType{"start", "completed"},
}

prediction, err := client.CreatePrediction(context.Background(), version, input, &webhook, false)

License

Replicate's Go client is released under the Apache 2.0 license. See LICENSE.txt.

Documentation

Index

Constants

This section is empty.

Variables

Functions

func Paginate

func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) (<-chan []T, <-chan error)

Paginate takes a Client and an initial Page, and iterates through pages of results. It returns a channel of result slices and a channel of errors.

Types

type APIError

type APIError struct {
	Detail string `json:"detail"`
}

APIError represents an error returned by the Replicate API.

func (APIError) Error

func (e APIError) Error() string

Error implements the error interface.

type Client

type Client struct {
	Auth       string
	UserAgent  *string
	BaseURL    string
	HTTPClient *http.Client
}

Client is a client for the Replicate API.

func NewClient

func NewClient(auth string, options ...ClientOption) *Client

NewClient creates a new Replicate API client.

func (*Client) CancelTraining

func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Training, error)

CancelTraining sends a request to the Replicate API to cancel a training.

func (*Client) CreatePrediction

func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)

CreatePrediction sends a request to the Replicate API to create a prediction.

func (*Client) CreateTraining

func (r *Client) CreateTraining(ctx context.Context, model_owner string, model_name string, version string, destination string, input TrainingInput, webhook *Webhook) (*Training, error)

CreateTraining sends a request to the Replicate API to create a new training.

func (*Client) GetCollection

func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error)

GetCollection returns a collection by slug.

func (*Client) GetModel

func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error)

GetModel retrieves information about a model.

func (*Client) GetModelVersion

func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error)

GetModelVersion retrieves a specific version of a model.

func (*Client) GetPrediction

func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error)

GetPrediction retrieves a prediction from the Replicate API by its ID.

func (*Client) GetTraining

func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error)

GetTraining sends a request to the Replicate API to get a training.

func (*Client) ListCollections

func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error)

ListCollections returns a list of all collections.

func (*Client) ListModelVersions

func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error)

ListModelVersions lists the versions of a model.

func (*Client) ListPredictions

func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error)

ListPredictions returns a paginated list of predictions.

func (*Client) ListTrainings

func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error)

ListTrainings returns a list of trainings.

func (*Client) Run

func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error)

func (*Client) Wait

func (r *Client) Wait(ctx context.Context, prediction Prediction, interval time.Duration, maxAttempts int) (*Prediction, error)

Wait for a prediction to finish.

This function blocks until the prediction has finished, or the context is cancelled. If the prediction has already finished, the prediction is returned immediately. If the prediction has not finished after maxAttempts, an error is returned. If interval is less than or equal to zero, an error is returned. If maxAttempts is 0, there is no limit to the number of attempts. If maxAttempts is negative, an error is returned.

type ClientOption

type ClientOption func(*Client)

ClientOption is a function that modifies a Client.

func WithBaseURL

func WithBaseURL(baseURL string) ClientOption

WithBaseURL sets the base URL for the client.

func WithHTTPClient

func WithHTTPClient(httpClient *http.Client) ClientOption

WithHTTPClient sets the HTTP client used by the client.

func WithUserAgent

func WithUserAgent(userAgent string) ClientOption

WithUserAgent sets the User-Agent header on requests made by the client.

type Collection

type Collection struct {
	Name        string   `json:"name"`
	Slug        string   `json:"slug"`
	Description string   `json:"description"`
	Models      *[]Model `json:"models,omitempty"`
}

type Model

type Model struct {
	URL            string        `json:"url"`
	Owner          string        `json:"owner"`
	Name           string        `json:"name"`
	Description    string        `json:"description"`
	Visibility     string        `json:"visibility"`
	GithubURL      string        `json:"github_url"`
	PaperURL       string        `json:"paper_url"`
	LicenseURL     string        `json:"license_url"`
	RunCount       int           `json:"run_count"`
	CoverImageURL  string        `json:"cover_image_url"`
	DefaultExample *Prediction   `json:"default_example"`
	LatestVersion  *ModelVersion `json:"latest_version"`
}

type ModelVersion

type ModelVersion struct {
	ID            string      `json:"id"`
	CreatedAt     string      `json:"created_at"`
	CogVersion    string      `json:"cog_version"`
	OpenAPISchema interface{} `json:"openapi_schema"`
}

type Page

type Page[T any] struct {
	Previous *string `json:"previous,omitempty"`
	Next     *string `json:"next,omitempty"`
	Results  []T     `json:"results"`
}

Page represents a paginated response from Replicate's API.

type Prediction

type Prediction struct {
	ID      string           `json:"id"`
	Status  Status           `json:"status"`
	Version string           `json:"version"`
	Input   PredictionInput  `json:"input"`
	Output  PredictionOutput `json:"output,omitempty"`
	Source  Source           `json:"source"`
	Error   interface{}      `json:"error,omitempty"`
	Logs    *string          `json:"logs,omitempty"`
	Metrics *struct {
		PredictTime *float64 `json:"predict_time,omitempty"`
	} `json:"metrics,omitempty"`
	Webhook             *string            `json:"webhook,omitempty"`
	WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"`
	URLs                map[string]string  `json:"urls,omitempty"`
	CreatedAt           string             `json:"created_at"`
	StartedAt           *string            `json:"started_at,omitempty"`
	CompletedAt         *string            `json:"completed_at,omitempty"`
}

type PredictionInput

type PredictionInput map[string]interface{}

type PredictionOutput

type PredictionOutput interface{}

type Source

type Source string
const (
	SourceWeb Source = "web"
	SourceAPI Source = "api"
)

type Status

type Status string
const (
	Starting   Status = "starting"
	Processing Status = "processing"
	Succeeded  Status = "succeeded"
	Failed     Status = "failed"
	Canceled   Status = "canceled"
)

func (Status) String

func (s Status) String() string

func (Status) Terminated

func (s Status) Terminated() bool

type Training

type Training Prediction

type TrainingInput

type TrainingInput PredictionInput

type Webhook

type Webhook struct {
	URL    string
	Events []WebhookEventType
}

type WebhookEventType

type WebhookEventType string
const (
	WebhookEventStart     WebhookEventType = "start"
	WebhookEventOutput    WebhookEventType = "output"
	WebhookEventLogs      WebhookEventType = "logs"
	WebhookEventCompleted WebhookEventType = "completed"
)

func (WebhookEventType) String

func (w WebhookEventType) String() string

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL