0

I am very new to reinforcement learning and I am trying to make a model traverse a very small graph, but it does not seem to be learning anything. I tried to follow the DQN tutorial on https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html and then replace the environment with my own, but it doesn't seem to convert at all.

Is anyone willing to take a look at my code?

    def __init__(self, adjacency: np.array):
        super().__init__()
        self.action_space = spaces.Discrete(4,)
        self.observation_space = spaces.Discrete(4,)
        self._graph = nx.DiGraph(adjacency)
        self._state = 0
        self._previous_states = []
        self._num_steps = 0
        self._episode_ended = False

    def get_state(self):
        return self._state

    def _remap_move(self) -> int:
        return np.argmax(nx.linalg.adjacency_matrix(self._graph)[self._state, :])

    def _legal_move(self, action) -> bool:
        return bool(nx.linalg.adjacency_matrix(self._graph)[self._state, action])

    def _output_tuple(self, reward):
        return self._state, reward, self._episode_ended, _

                    
    def step(self, action):
        if self._episode_ended:
            self.reset()

        if self._num_steps < 50:
            if not self._legal_move(action):
                action = self._remap_move()

            self._num_steps += 1

            if self._state == action:
                return self._output_tuple(-1)

            self._state = action

            if action == 4:
                self._episode_ended = True
                return self._output_tuple(100)

            else:
                # if action in self._previous_states:
                #     return self._output_tuple(-1)
                # else:
                self._previous_states.append(action)
                return self._output_tuple(1)

        self._episode_ended = True
        return self._output_tuple(-1)


    def reset(self):
        self._state = 0
        self._episode_ended = False

    def render(self, mode="human"):
        pass

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

BATCH_SIZE = 32
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
n_actions = 4

adj = np.array([[0, 1, 0, 0], [0, 1, 1, 0], [1, 1, 1, 1], [1, 0, 1, 1]])
g = nx.DiGraph(adj)
env = GraphWalk(g)

class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(1, 100),
            nn.ReLU(),
            nn.Linear(100, n_actions)
        )
    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

policy_net = DQN(1, n_actions).to(device)
target_net = DQN(1, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

steps_done = 0


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)

    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)

    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])

    state_batch = torch.cat(batch.state)

    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net



    state_action_values = policy_net(state_batch).gather(1, action_batch)


    # print(state_action_values)
    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    print(loss)
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # print(state)
            # t.max(1) will return largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            # print(policy_net(state))
            return torch.tensor([[policy_net(state).argmax()]], device=device, dtype=torch.long)
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations():
    plt.figure(2)
    plt.clf()
    rew = torch.tensor(rewards, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(rew.numpy())
    # Take 100 episode averages and plot them too
    if len(rew) >= 100:
        means = rew.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

def get_cuda_state():
    return torch.Tensor([env.get_state()])


actions = []
rewards = []

num_episodes = 100
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()

    last_state = get_cuda_state()
    current_state = get_cuda_state()

    state = current_state - last_state
    r = []
    for t in count():
        # Select and perform an action
        action = select_action(state)
        actions.append(action.item())
        _, reward, done, _ = env.step(action.item())
        r.append(reward)
        reward = torch.tensor([reward], device=device)

        # Observe new state
        last_state = current_state
        current_state = get_cuda_state()
        if not done:
            next_state = current_state - last_state
        else:
            next_state = None

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()
        if done:
            rewards.append(sum(r)/len(r))
            # plot_durations()
            break
    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')

SVJ
  • 1

0 Answers0