deepq

package
v0.1.1 (Latest)
Warning

This package is not in the latest version of its module.

Published: May 13, 2023 | License: Apache-2.0 | Imports: 14 | Imported by: 0

README

Deep Q-learning

Implementation of the DeepQ algorithm with Double Q.

How it works

DeepQ builds on standard Q-learning.

(Figure: Q-learning)
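For reference, the standard (tabular) Q-learning update that DeepQ builds on can be sketched as follows. This is a minimal illustration and not code from this package; the `QTable` type and the `Update` function are hypothetical names.

```go
package main

// QTable maps a discrete state index to the estimated value of each action
// (hypothetical type, for illustration only).
type QTable [][]float64

// maxQ returns the largest action value stored for state s.
func maxQ(q QTable, s int) float64 {
	best := q[s][0]
	for _, v := range q[s][1:] {
		if v > best {
			best = v
		}
	}
	return best
}

// Update applies one tabular Q-learning step:
//
//	Q(s,a) <- Q(s,a) + alpha * (r + gamma*max_a' Q(s',a') - Q(s,a))
//
// When done is true the episode ended, so no future value is bootstrapped.
func Update(q QTable, s, a int, r float64, next int, done bool, alpha, gamma float64) {
	target := r
	if !done {
		target += gamma * maxQ(q, next)
	}
	q[s][a] += alpha * (target - q[s][a])
}
```

Every state/action pair needs its own table cell here, which is the limitation the neural-network approximation in DeepQ is meant to remove.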

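With DeepQ, rather than being stored in a table, Q-values are approximated using a neural network. Because the network's weights are shared across states, it can generalize to states it has not seen exactly, which makes it possible to handle large or continuous state spaces that a table cannot enumerate.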
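To make the table-versus-approximation contrast concrete, here is a hedged sketch that swaps the table for the simplest possible function approximator, a linear model (not the neural network this package uses). `LinearQ`, `Value`, and `Step` are hypothetical names; the point is that one gradient step moves shared weights, so the estimate changes for every state with similar features.

```go
package main

// LinearQ approximates Q(s, a) as the dot product of a per-action weight
// vector with the state features. DeepQ uses a neural network instead;
// this linear model is only a simplified stand-in.
type LinearQ struct {
	W [][]float64 // W[a] holds the weights for action a
}

// Value returns the approximate Q-value of action a in state s.
func (m *LinearQ) Value(s []float64, a int) float64 {
	sum := 0.0
	for i, x := range s {
		sum += m.W[a][i] * x
	}
	return sum
}

// Step performs one gradient step that reduces the squared error
// (target - Q(s,a))^2 / 2 for a fixed target value.
func (m *LinearQ) Step(s []float64, a int, target, lr float64) {
	errTD := target - m.Value(s, a)
	for i, x := range s {
		m.W[a][i] += lr * errTD * x
	}
}
```

After training on state `{1, 2}`, the model already returns a non-zero estimate for the unseen state `{1, 1}`, something a lookup table cannot do.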

DeepQ also uses experience replay: at every step the agent stores the state, action, and outcome in memory, and during training it samples random batches from that memory. Sampling at random breaks up the correlation between consecutive steps.

Double Q is also implemented: the target (the expected future reward) is computed by a separate target network whose weights are periodically copied from the "online" network that makes the predictions. Because the target network changes only at these copies, it gives the online network a more stable target to learn toward.
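The periodic weight copy described above can be sketched as follows. Weights are flattened into a `[]float64` purely for illustration; `Params`, `DoubleNet`, and `AfterTrainStep` are hypothetical names, and only the `UpdateTargetSteps` name is borrowed from this package's hyperparameters.

```go
package main

// Params is a flat view of a network's weights (hypothetical; a real
// network would hold one tensor per layer).
type Params []float64

// DoubleNet keeps an online network that is trained every step and a
// target network that is refreshed only every UpdateTargetSteps steps.
type DoubleNet struct {
	Online            Params
	Target            Params // must have the same length as Online
	UpdateTargetSteps int    // must be > 0
	steps             int
}

// AfterTrainStep should be called once per training step. It copies the
// online weights into the target network every UpdateTargetSteps steps
// and reports whether a copy happened.
func (d *DoubleNet) AfterTrainStep() bool {
	d.steps++
	if d.steps%d.UpdateTargetSteps != 0 {
		return false
	}
	copy(d.Target, d.Online)
	return true
}
```

Between copies the target stays frozen even while the online weights change, which is where the stability comes from.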

Examples

See the experiments folder for example implementations.

Roadmap

  • Prioritized replay
  • Dueling Q
  • Soft updates
  • More environments


Documentation

Overview

Package deepq is an agent implementation of the DeepQ algorithm.


Constants

This section is empty.

Variables

var DefaultAgentConfig = &AgentConfig{
	Hyperparameters: DefaultHyperparameters,
	PolicyConfig:    DefaultPolicyConfig,
	Base:            agentv1.NewBase("DeepQ"),
}

DefaultAgentConfig is the default config for a dqn agent.

var DefaultAtariLayerBuilder = func(x, y *modelv1.Input) []layer.Config {
	return []layer.Config{
		layer.Conv2D{Input: 1, Output: 32, Width: 8, Height: 8, Stride: []int{4, 4}},
		layer.Conv2D{Input: 32, Output: 64, Width: 4, Height: 4, Stride: []int{2, 2}},
		layer.Conv2D{Input: 64, Output: 64, Width: 3, Height: 3, Stride: []int{1, 1}},
		layer.Flatten{},
		layer.FC{Input: 6400, Output: 512},
		layer.FC{Input: 512, Output: y.Squeeze()[0], Activation: layer.Linear},
	}
}

DefaultAtariLayerBuilder is the default layer builder for atari environments.
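The `FC{Input: 6400}` layer must match the number of features the `Flatten` layer produces, and that number depends on the input frame size. The helper below is a hypothetical sanity check that assumes unpadded ("valid") convolutions. This page does not show how the `Conv2D` layers pad, so treat the numbers as illustrative. Under that assumption, 6400 = 10 × 10 × 64 corresponds to a 108 × 108 input, while 84 × 84 frames would give 7 × 7 × 64 = 3136.

```go
package main

// convOut returns the output length of an unpadded ("valid")
// convolution along one axis.
func convOut(in, kernel, stride int) int {
	return (in-kernel)/stride + 1
}

// atariFlatSize returns the flattened feature count after the three
// Conv2D layers of DefaultAtariLayerBuilder for an h x w input,
// assuming no padding (an assumption, see above).
func atariFlatSize(h, w int) int {
	layers := []struct{ kernel, stride int }{{8, 4}, {4, 2}, {3, 1}}
	for _, l := range layers {
		h = convOut(h, l.kernel, l.stride)
		w = convOut(w, l.kernel, l.stride)
	}
	const channels = 64 // Output channels of the last Conv2D layer
	return h * w * channels
}
```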

var DefaultAtariPolicyConfig = &PolicyConfig{
	Loss:         modelv1.MSE,
	Optimizer:    g.NewRMSPropSolver(g.WithBatchSize(20)),
	LayerBuilder: DefaultAtariLayerBuilder,
	BatchSize:    20,
	Track:        true,
}

DefaultAtariPolicyConfig is the default policy config for atari environments.

var DefaultFCLayerBuilder = func(x, y *modelv1.Input) []layer.Config {
	return []layer.Config{
		layer.FC{Input: x.Squeeze()[0], Output: 24},
		layer.FC{Input: 24, Output: 24},
		layer.FC{Input: 24, Output: y.Squeeze()[0], Activation: layer.Linear},
	}
}

DefaultFCLayerBuilder is a default fully connected layer builder.

var DefaultHyperparameters = &Hyperparameters{
	Epsilon:           common.DefaultDecaySchedule(),
	Gamma:             0.95,
	UpdateTargetSteps: 100,
	BufferSize:        10e6,
}

DefaultHyperparameters are the default hyperparameters.
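`DefaultHyperparameters` uses `common.DefaultDecaySchedule()` for epsilon, whose exact shape is not documented on this page. As an assumption, a common choice is exponential decay with a floor, sketched here under the hypothetical name `DecaySchedule`. It does not describe the `common` package's actual parameters.

```go
package main

import "math"

// DecaySchedule is a hypothetical exponential epsilon schedule: it starts
// at Start, is multiplied by Rate on each call to Value, and never drops
// below Min.
type DecaySchedule struct {
	Start, Rate, Min float64
	steps            int
}

// Value returns the current epsilon and advances the schedule by one step.
func (d *DecaySchedule) Value() float64 {
	eps := d.Start * math.Pow(d.Rate, float64(d.steps))
	d.steps++
	return math.Max(eps, d.Min)
}
```

The floor keeps a small amount of exploration alive after the agent has mostly switched to exploiting.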

var DefaultPolicyConfig = &PolicyConfig{
	Loss:         modelv1.MSE,
	Optimizer:    g.NewAdamSolver(g.WithLearnRate(0.0005)),
	LayerBuilder: DefaultFCLayerBuilder,
	BatchSize:    20,
	Track:        true,
}

DefaultPolicyConfig is the default configuration for a policy.

Functions

func MakePolicy

func MakePolicy(name string, config *PolicyConfig, base *agentv1.Base, env *envv1.Env) (modelv1.Model, error)

MakePolicy makes a policy model.

Types

type Agent

type Agent struct {
	// Base for the agent.
	*agentv1.Base

	// Hyperparameters for the dqn agent.
	*Hyperparameters

	// Policy for the agent.
	Policy model.Model

	// Target policy for double Q learning.
	TargetPolicy model.Model

	// Epsilon is the rate at which the agent explores vs exploits.
	Epsilon common.Schedule
	// contains filtered or unexported fields
}

Agent is a dqn agent.

func NewAgent

func NewAgent(c *AgentConfig, env *envv1.Env) (*Agent, error)

NewAgent returns a new dqn agent.

func (*Agent) Action

func (a *Agent) Action(state *tensor.Dense) (action int, err error)

Action selects the best known action for the given state.

func (*Agent) Learn

func (a *Agent) Learn() error

Learn trains the agent.
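A DQN learning step usually samples a batch from memory, computes regression targets with the target network, and fits the online policy to them. The internals of `Learn` are not shown here, so this hedged sketch covers only the target computation. `QFunc` and `Targets` are hypothetical names standing in for a network's forward pass and batch preprocessing.

```go
package main

// QFunc returns the Q-value of every action for a state (hypothetical
// stand-in for a policy network's forward pass).
type QFunc func(state []float32) []float32

// Targets computes the regression target for each sampled transition:
//
//	y = r                                  if the episode ended
//	y = r + gamma * max_a' Qtarget(s', a')  otherwise
func Targets(target QFunc, rewards []float32, next [][]float32, done []bool, gamma float32) []float32 {
	ys := make([]float32, len(rewards))
	for i, r := range rewards {
		ys[i] = r
		if done[i] {
			continue // no future reward after a terminal step
		}
		qs := target(next[i])
		best := qs[0]
		for _, q := range qs[1:] {
			if q > best {
				best = q
			}
		}
		ys[i] += gamma * best
	}
	return ys
}
```

The online network's predictions for the taken actions would then be regressed toward `ys` with the configured loss (MSE by default).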

func (*Agent) Remember

func (a *Agent) Remember(event *Event)

Remember an event.

type AgentConfig

type AgentConfig struct {
	// Base for the agent.
	Base *agentv1.Base

	// Hyperparameters for the agent.
	*Hyperparameters

	// PolicyConfig for the agent.
	PolicyConfig *PolicyConfig
}

AgentConfig is the config for a dqn agent.

type Event

type Event struct {
	*envv1.Outcome

	// State by which the action was taken.
	State *tensor.Dense

	// Action that was taken.
	Action int
	// contains filtered or unexported fields
}

Event is an event that occurred.

func NewEvent

func NewEvent(state *tensor.Dense, action int, outcome *envv1.Outcome) *Event

NewEvent returns a new event.

func (*Event) Print

func (e *Event) Print()

Print the event.

type Hyperparameters

type Hyperparameters struct {
	// Gamma is the discount factor (0≤γ≤1). It determines how much importance we want to give to future
	// rewards. A high value for the discount factor (close to 1) captures the long-term effective reward,
	// whereas a discount factor of 0 makes our agent consider only immediate reward, hence making it greedy.
	Gamma float32

	// Epsilon is the rate at which the agent should exploit vs explore.
	Epsilon common.Schedule

	// UpdateTargetSteps determines how often the target network updates its parameters.
	UpdateTargetSteps int

	// BufferSize is the buffer size of the memory.
	BufferSize int
}

Hyperparameters for the dqn agent.
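The effect of `Gamma` can be seen in the discounted return G = r₀ + γr₁ + γ²r₂ + …, which this sketch computes backwards as Gₜ = rₜ + γGₜ₊₁. `DiscountedReturn` is a hypothetical helper, not part of this package. With the default γ = 0.95, three rewards of 1 are worth about 2.85, and a reward 10 steps away counts for 0.95¹⁰ ≈ 0.599 of its face value.

```go
package main

// DiscountedReturn computes G = r0 + gamma*r1 + gamma^2*r2 + ...
// by iterating backwards with G_t = r_t + gamma * G_{t+1}.
func DiscountedReturn(rewards []float32, gamma float32) float32 {
	var g float32
	for i := len(rewards) - 1; i >= 0; i-- {
		g = rewards[i] + gamma*g
	}
	return g
}
```

With `gamma` set to 0, only the first reward counts, which is the greedy case described in the `Gamma` comment.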

type LayerBuilder

type LayerBuilder func(x, y *modelv1.Input) []layer.Config

LayerBuilder builds layers.

type Memory

type Memory struct {
	*deque.Deque
}

Memory for the dqn agent.

func NewMemory

func NewMemory() *Memory

NewMemory returns a new Memory store.

func (*Memory) Sample

func (m *Memory) Sample(batchsize int) ([]*Event, error)

Sample from the memory with the given batch size.

type PolicyConfig

type PolicyConfig struct {
	// Loss function to evaluate network performance.
	Loss modelv1.Loss

	// Optimizer to optimize the weights with regard to the error.
	Optimizer g.Solver

	// LayerBuilder is a builder of layers.
	LayerBuilder LayerBuilder

	// BatchSize of the updates.
	BatchSize int

	// Track is whether to track the model.
	Track bool
}

PolicyConfig holds the hyperparameters for a policy.

Directories

  • experiments
