Source code for cobs.replay_buffer

import os, random, pickle
import numpy as np


[docs]class ReplayBuffer: def __init__(self, capacity=None, seed=None, chkpt_dir=None): """ ReplayBuffer is a buffer stores the previous states, corresponding actions, and rewards. :param capacity: Maximum number of history step the buffer should hold. :param seed: The random seed when sample from the buffer. :param chkpt_dir: The directory to dump the buffer. """ if chkpt_dir is not None and not os.path.exists(chkpt_dir): os.makedirs(chkpt_dir) self.chkpt_dir = chkpt_dir if seed is not None: random.seed(seed) self.capacity = capacity self.buffer = [] self.position = 0 self.ignore_set = set() self.save_set = set()
[docs] def push(self, prev_state, prev_action, current_state, done): """ Add one step into the buffer. :param prev_state: The state before performing an action. :param prev_action: The action current step performed. :param current_state: The resulting state after the action is being evaluated. :param done: The current_state is terminate state or not. :return: None """ if self.capacity is None or len(self.buffer) < self.capacity: self.buffer.append(None) if len(self.save_set) == 0: self.save_set = set(current_state.keys()) - self.ignore_set prev_state = {key: prev_state[key] for key in self.save_set} current_state = {key: current_state[key] for key in self.save_set} self.buffer[self.position] = (prev_state, prev_action, current_state, done) self.position += 1 if self.capacity: self.position %= self.capacity
[docs] def sample(self, batch_size): """ Sample some experiences from the buffer. :param batch_size: Number of steps to sample. :return: four lists for ``previous_state``, ``action``, ``resulting_state``, ``terminate``, respectively. """ batch = random.sample(self.buffer, batch_size) state, action, next_state, done = map(np.stack, zip(*batch)) return state, action, next_state, done
[docs] def save(self, num): """ Dump the current buffer to local disk. :param num: unique identifier to specify the version of the buffer dump. :return: None """ checkpoint_file = os.path.join(self.chkpt_dir, f'memory_{num}') dump_dict = { 'capacity': self.capacity, 'buffer': self.buffer, 'position': self.position } with open(checkpoint_file, "wb") as f: pickle.dump(dump_dict, f)
[docs] def load(self, num): """ Load a dumped buffer from the local disk. :param num: unique identifier to specify the version of the buffer dump. :return: None """ checkpoint_file = os.path.join(self.chkpt_dir, f'memory_{num}') with open(checkpoint_file, "rb") as f: dump_dict = pickle.load(f) self.capacity = dump_dict['capacity'] self.buffer = dump_dict['buffer'] self.position = dump_dict['position']
def __len__(self): return len(self.buffer)
[docs] def reset(self): """ Clear the buffer. :return: None """ self.buffer = list() self.position = 0 self.save_set = set() self.ignore_set = set()
[docs] def terminate(self): """ Set the last added state to terminate state. :return: None """ prev_buffer = list(self.buffer[self.position - 1]) prev_buffer[3] = True self.buffer[self.position - 1] = tuple(prev_buffer)
[docs] def set_ignore(self, ignore_set): """ Set which state features should not be stored to save memory space. :param ignore_set: A set contains all the key of state features that should be ignored. :return: None """ self.ignore_set = ignore_set