본문 바로가기

강화학습

8. [강화학습] DQN(카트폴)

반응형

 안녕하세요 '코딩 오페라'블로그를 운영하고 있는 저는 'Master.M'입니다.

오늘 알아볼 내용은 'DQN'입니다. 

 

 저의 경우 2016년 알파고와 이세돌의 바둑 대결로 인해 인공지능에 관심이 많이 생기기 시작했고 이후 구글의 딥마인드 팀에서 발표한 DQN논문 특히 아타리사의 '브레이트 아웃' 게임을 하는 것을 보고 많은 감명을 받아 '강화 학습'이라는 학문에 많은 관심을 갖게 되었습니다. 그래서 오늘부터는 강화 학습에 대해 차분히 정리를 해보도록 하겠습니다.

 

 저는 '파이썬과 케라스로 배우는 강화 학습'이라는 책을 읽으면서 독학을 하였습니다. 이 글은 이 책을 참고하여 제작합니다.(광고 아닙니다!!)

 

https://codingopera.tistory.com/27

 

7. [강화학습] 딥살사(Deep SARSA)

 안녕하세요 '코딩 오페라'블로그를 운영하고 있는 저는 'Master.M'입니다. 오늘 알아볼 내용은 '딥 살사(Deep SARSA)'입니다.  저의 경우 2016년 알파고와 이세돌의 바둑 대결로 인해 인공지능에 관

codingopera.tistory.com

 저번 시간에 저희는 온 폴리시 알고리즘인 SARSA를 인공신경망에 적용한 Deep SARSA에 대해 알아보았습니다. 혹시 안 보신 분들은 먼저 위 링크의 글을 읽어보시기 바랍니다. 그럼 오프 폴리시 알고리즘인 큐 러닝도 인공신경망에 적용이 가능할까요? 네 가능합니다. 그런데 한 가지 '경험 리플레이(Experience Replay)'라는 장치가 필요합니다. 

 

경험 리플레이(Experience Replay)

 경험 리플레이라는 아이디어는 에어전트가 환경에서 탐험하여 얻은 샘플(s, a, r, s')을 메모리에 저장한다는 것입니다. 샘플을 저장하는 메모리는 리플레이 메모리(Replay Memory)라고 합니다. 에어전트가 학습할 때는 리플레이 메모리에서 여러 개의 샘플을 무작위로 뽑아서 뽑은 샘플에 대해 인공신경망을 업데이트합니다. 이 과정을 매 타임 스텝마다 반복합니다. 이를 그림으로 나타내면 아래와 같습니다. 

 

이러한 경험 리플레이를 사용하면 아래와 같은 장점이 있습니다. 

 

1. 샘플 간의 시간적 상관관계가 없다.

 Deep SARSA 알고리즘과 같은 온폴리시 알고리즘을 사용하면 에이전트가 안 좋은 상황에 빠질 경우 이에 맞게 학습을 해버린다는 단점이 있습니다. 그러나 리플레이 메모리를 사용하면 샘플 간의 시간적 상관관계가 없어 이러한 일이 발생하지 않습니다. 

 

2. 학습이 안정적이다.

 샘플 하나로 인공신경망을 업데이트하는 것이 아니라 리플레이 메모리에서 추출한 여러 개의 샘플을 통해 인공신경망을 업데이트하므로 학습이 안정적입니다.

 

 

타깃 신경망(Target Network)

큐러닝을 통한 큐함수의 업데이트

 

타깃 네트워크를 이용한 DQN 오류함수 정의

 

 부트스트랩의 문제점은 업데이트의 목표가 되는 정답이 계속 변한다는 것입니다. 그런데 정답을 내는 인공신경망 자체도 계속 업데이트되면 부트스트랩의 문제점은 더 심해질 것입니다. 이를 방지하기 위해 정답을 만들어내는 인공신경망을 일정 시간 동안 유지합니다. 타깃 신경망을 따로 만들어 타깃 신경망에서 정답에 해당하는 값을 구합니다. 구한 정답을 통해 다른 인공신경망을 계속 학습시키며 타깃 신경망은 일정한 시간 간격마다 그 인공신경망으로 업데이트합니다. 위의 수식에서 타깃 신경망은 theta^(-)를 매개변수로, 인공신경망은 theta로 표현합니다. 

 

 

그러면 지금부터 Open Ai gym의 카트폴 게임을 이용하여 DQN 알고리즘을 실습해보도록 하겠습니다. 

우선 위와 같이 필요한 라이브러리들을 import 해줍니다. 

 

 다음으로는 DQN이라는 클래스를 만들어 인공신경망 모델을 만들어줍니다. 위의 경우 layer가 self.fc1, self.fc2, self.fc_out 이렇게 3개가 있습니다. 처음 두 layer의 출력은 각각 30개로 설정하였고 마지막은 output 이므로 action_size만큼 출력해줍니다. 즉 action을 골라줍니다. call 함수를 통해 state(=x)를 입력으로 넣고 출력으로 q를 받습니다. 

- kernel_initializer=RandomUniform(): 가중치를 랜덤하게 초기화

 

deque 함수를 이용해 길이가 2000개인 리플레이 메모리를 만들어줍니다. 

 

타깃 모델과 일반 모델을 만들어줍니다. 

 

Epsilon Greedy Policy를 이용하여 Action을 선택해 줍니다. 

 

메모리에서 미니 배치 크기만큼 무작위로 샘플을 추출해줍니다. 이를 다시 states, actions, rewards, next_states, dones 데이터로 분류해줍니다. 

- random.sample(): 무작위로 샘플

 

 마지막으로 train_model 함수를 이용해 모델을 훈련시켜줍니다. 여기서 중요한 함수들은 다음과 같습니다. 

with tf.GradientTape() as tape: 경사 기록 장치(이 안에서 수행되는 연산의 경사가 기록)

tape.watch(): gradient를 기록하기 위해

tf.reduce_sum: 모든 성분의 총합

tf.reduce_mean: 모든 성분의 평균

zip(A, B): A와 B원소들을 묶어준다

tf.one_hot: 각 라벨에 해당하는 인덱스를 해당 위치만 1로 매핑하고, 나머지는 0으로 매핑한 벡터로 표현하는 방법

 

 전체 코드는 아래 제 깃허브 링크에 있으니 직접 코드를 실행하실 분들은 참고하시기 바라보니

https://github.com/CodingOpera/RL/blob/main/2-cartpole/1-dqn2.py

 

GitHub - CodingOpera/RL

Contribute to CodingOpera/RL development by creating an account on GitHub.

github.com

 

  지금 까지 저희는 'DQN'에 대해 알아보았습니다. 도움이 되셨나요? 만약 되셨다면 구독 및 좋아요로 표현해 주시면 정말 많은 힘이 됩니다. 궁금한 사항 혹은 앞으로 다루어 주었으면 좋을 주제가 있으시면 댓글 남겨주시면 감사하겠습니다. 저는 '코딩 오페라'의 'Master.M'이었습니다. 감사합니다.

반응형

'강화학습' 카테고리의 다른 글

10. [강화학습] A2C  (0) 2022.04.26
9. [강화학습] REINFORCE  (0) 2022.04.26
7. [강화학습] 딥살사(Deep SARSA)  (0) 2022.04.04
6. [강화학습] 큐러닝(Q-learning)  (0) 2022.03.25
5. [강화학습] 살사(SARSA)  (0) 2022.03.25