q

package
v0.0.0-...-225e849
Published: Oct 22, 2020 License: Apache-2.0 Imports: 13 Imported by: 1

README

Q-learning

An implementation of the Q-learning algorithm with adaptive learning.

How it works

In Q-learning the agent stores Q-values (quality values) for each state that it encounters. Q-values are determined by the following equation.

Q(state, action) ← (1 − α) Q(state, action) + α (reward + γ max_a Q(next state, all actions))

Q-learning is an off-policy form of temporal difference learning. The agent learns by storing a quality value for each state it encounters, based on the reward it received for the action taken along with the discounted future reward. Taking the future reward into account at each value iteration forms a Markov chain that converges toward the highest reward.
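
As a minimal sketch of this value update in Go (illustrative only, not the package's internal code):

func update(current, reward, maxNext, alpha, gamma float32) float32 {
	// The new Q(state, action) blends the old estimate with the observed
	// reward plus the discounted best estimate for the next state.
	return (1-alpha)*current + alpha*(reward+gamma*maxNext)
}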

An agent will explore or exploit the Q-values based on the epsilon hyperparameter.

The implemented agent also employs adaptive learning, in which the alpha and epsilon hyperparameters are dynamically tuned based on the timestep and an ada divisor parameter.
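
One common scheme for this kind of adaptive tuning, sketched below, decays both hyperparameters logarithmically with the timestep, scaled by the ada divisor; the exact schedule used by this agent may differ:

import "math"

// adapt returns decayed epsilon and alpha values for a given timestep.
// Early timesteps keep the values near max; later timesteps shrink them
// toward min, with the rate of decay controlled by the ada divisor.
func adapt(timestep int, adaDivisor, min, max float64) (epsilon, alpha float64) {
	v := math.Min(max, math.Max(min, 1.0-math.Log10(float64(timestep+1)/adaDivisor)))
	return v, v
}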

Q-learning does not work well in continuous environments. The pkg/v1/env package provides normalization adapters; one of them performs discretization and can be used to make continuous states discrete.
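
As a rough illustration of the discretization idea (not the adapter's actual API), each continuous observation value can be clamped and mapped into a fixed number of bins:

// discretize clamps each continuous value to [lo, hi] and maps it to a bin
// index in [0, bins), producing a discrete state suitable for hashing.
func discretize(obs []float32, lo, hi float32, bins int) []int {
	out := make([]int, len(obs))
	for i, v := range obs {
		if v < lo {
			v = lo
		}
		if v > hi {
			v = hi
		}
		out[i] = int(float32(bins-1) * (v - lo) / (hi - lo))
	}
	return out
}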

Examples

See the experiments folder for example implementations.

Documentation

Overview

Package q is an agent implementation of the Q-learning algorithm.

Index

Constants

This section is empty.

Variables

var DefaultAgentConfig = &AgentConfig{
	Hyperparameters: DefaultHyperparameters,
	Base:            agentv1.NewBase("Q"),
}

DefaultAgentConfig is the default config for a Q-learning agent.

var DefaultHyperparameters = &Hyperparameters{
	Epsilon:    common.NewConstantSchedule(0.1),
	Gamma:      0.6,
	Alpha:      0.1,
	AdaDivisor: 5.0,
}

DefaultHyperparameters are the default hyperparameters for a Q-learning agent.

Functions

func HashState

func HashState(observations *tensor.Dense) uint32

HashState hashes observations into an integer value. Note: this requires observations to always occur in the same order.
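
For example, assuming the observation tensor comes from gorgonia.org/tensor (inferred from the *tensor.Dense type in the signature), the resulting hash is the key used to index a Table:

// Build an observation vector and hash it into the key used by a Table.
obs := tensor.New(tensor.WithBacking([]float32{0.02, -0.15, 0.31, 0.0}))
state := HashState(obs)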

Types

type Agent

type Agent struct {
	*agentv1.Base
	*Hyperparameters
	// contains filtered or unexported fields
}

Agent that utilizes the Q-Learning algorithm.

func NewAgent

func NewAgent(c *AgentConfig, env *envv1.Env) *Agent

NewAgent returns a new Q-learning agent.
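
For example (the env value is assumed to have been created with the pkg/v1/env package and is not shown):

// Construct an agent with the default configuration.
agent := NewAgent(DefaultAgentConfig, env)

// Or override the hyperparameters explicitly.
agent = NewAgent(&AgentConfig{
	Hyperparameters: &Hyperparameters{
		Epsilon:    common.NewConstantSchedule(0.2),
		Gamma:      0.9,
		Alpha:      0.1,
		AdaDivisor: 5.0,
	},
	Base: agentv1.NewBase("Q"),
}, env)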

func (*Agent) Action

func (a *Agent) Action(state *tensor.Dense) (action int, err error)

Action returns the action that should be taken given the state hash.

func (*Agent) Adapt

func (a *Agent) Adapt(timestep int)

Adapt will adjust the hyperparameters based on the timestep.

func (*Agent) Learn

func (a *Agent) Learn(action int, state *tensor.Dense, outcome *envv1.Outcome) error

Learn using the Q-learning algorithm: Q(state, action) ← (1 − α) Q(state, action) + α (reward + γ max_a Q(next state, all actions))
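
A rough training-loop sketch using the documented methods; the env.Reset and env.Step calls and the Outcome fields (Observation, Done) are assumptions about the pkg/v1/env API and may differ:

for episode := 0; episode < 100; episode++ {
	state, _ := env.Reset() // assumed env API
	for timestep := 0; ; timestep++ {
		// Tune alpha and epsilon for the current timestep.
		agent.Adapt(timestep)

		// Choose an action for the current state and apply it.
		action, err := agent.Action(state)
		if err != nil {
			log.Fatal(err)
		}
		outcome, err := env.Step(action) // assumed env API
		if err != nil {
			log.Fatal(err)
		}

		// Update the Q table from the observed outcome.
		if err := agent.Learn(action, state, outcome); err != nil {
			log.Fatal(err)
		}

		state = outcome.Observation // assumed Outcome field
		if outcome.Done {           // assumed Outcome field
			break
		}
	}
}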

func (*Agent) Visualize

func (a *Agent) Visualize()

Visualize the agent's internal state.

type AgentConfig

type AgentConfig struct {
	// Base for the agent.
	Base *agentv1.Base

	// Hyperparameters for the agent.
	*Hyperparameters

	// Table for the agent.
	Table Table
}

AgentConfig is the config for a Q-learning agent.

type Hyperparameters

type Hyperparameters struct {
	// Epsilon is the rate at which the agent should explore versus exploit. The lower the value,
	// the more exploitation.
	Epsilon common.Schedule

	// Gamma is the discount factor (0≤γ≤1). It determines how much importance we want to give to future
	// rewards. A high discount factor (close to 1) captures the long-term effective reward, whereas
	// a discount factor of 0 makes the agent consider only the immediate reward, making it greedy.
	Gamma float32

	// Alpha is the learning rate (0<α≤1). As in supervised learning, alpha is the extent
	// to which the Q-values are updated in every iteration.
	Alpha float32

	// AdaDivisor is used in adaptive learning to tune the hyperparameters.
	AdaDivisor float32
}

Hyperparameters for a Q-learning agent.

type MemTable

type MemTable struct {
	// contains filtered or unexported fields
}

MemTable is an in-memory Table with a row for every state and a column for every action. State is held as a hash of observations.

func (*MemTable) Clear

func (m *MemTable) Clear() error

Clear the table.

func (*MemTable) Get

func (m *MemTable) Get(state uint32, action int) (float32, error)

Get the Q value for the given state and action.

func (*MemTable) GetMax

func (m *MemTable) GetMax(state uint32) (action int, qValue float32, err error)

GetMax returns the action with the max Q value for a given state hash.

func (*MemTable) Print

func (m *MemTable) Print()

Print the table with a pretty printer.

func (*MemTable) Set

func (m *MemTable) Set(state uint32, action int, qValue float32) error

Set the quality of the action taken for a given state.

type Table

type Table interface {
	// GetMax returns the action with the max Q value for a given state hash.
	GetMax(state uint32) (action int, qValue float32, err error)

	// Get the Q value for the given state and action.
	Get(state uint32, action int) (float32, error)

	// Set the q value of the action taken for a given state.
	Set(state uint32, action int, value float32) error

	// Clear the table.
	Clear() error

	// Pretty print the table.
	Print()
}

Table is the quality table, which stores the quality of an action by state.

func NewMemTable

func NewMemTable(actionSpaceSize int) Table

NewMemTable returns a new MemTable with a column for every action in the action space; rows are keyed by state hash as states are encountered.
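
A brief sketch of using the table directly; the action-space size of 2 and the obs tensor are arbitrary placeholders:

// Create an in-memory table for an environment with 2 actions.
table := NewMemTable(2)

// Record a quality value for action 0 in some state, then read back the
// best action for that state.
state := HashState(obs) // obs is a *tensor.Dense observation
if err := table.Set(state, 0, 0.75); err != nil {
	log.Fatal(err)
}
action, qValue, err := table.GetMax(state)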

Directories

Path Synopsis
experiments
