본문 바로가기

강화학습

7. [강화학습] 딥살사(Deep SARSA)

반응형

 안녕하세요 '코딩 오페라'블로그를 운영하고 있는 저는 'Master.M'입니다.

오늘 알아볼 내용은 '딥 살사(Deep SARSA)'입니다. 

 

 저의 경우 2016년 알파고와 이세돌의 바둑 대결로 인해 인공지능에 관심이 많이 생기기 시작했고 이후 구글의 딥마인드 팀에서 발표한 DQN논문 특히 아타리사의 '브레이트 아웃' 게임을 하는 것을 보고 많은 감명을 받아 '강화 학습'이라는 학문에 많은 관심을 갖게 되었습니다. 그래서 오늘부터는 강화 학습에 대해 차분히 정리를 해보도록 하겠습니다.

 

 저는 '파이썬과 케라스로 배우는 강화 학습'이라는 책을 읽으면서 독학을 하였습니다. 이 글은 이 책을 참고하여 제작합니다.(광고 아닙니다!!)

 

https://codingopera.tistory.com/25

 

5. [강화학습] 살사(SARSA)

 안녕하세요 '코딩 오페라'블로그를 운영하고 있는 저는 'Master.M'입니다. 오늘 알아볼 내용은 '살사(SARSA)'입니다.  저의 경우 2016년 알파고와 이세돌의 바둑 대결로 인해 인공지능에 관심이 많

codingopera.tistory.com

 

 저번시간에 저희는 살사(SARAS) 알고리즘에 대해 알아보았습니다. 궁금하시거나 알고 싶으신 분들은 위의 링크를 참고하시기 바랍니다. 

 

1. 몬테카를로, 살사, 큐 러닝의 한계

 위의 알고리즘은 모델 프리(model free) 알고리즘으로 환경에 대한 모델이 필요했던 다이내믹 프로그램에 비해 많은 발전을 하였지만 아직 계산 복잡도차원의 저주 문제를 해결하지 못합니다. 이러한 문제가 발생하는 이유는 큐 함수를 테이블의 형태로 모든 행동 상태에 대해 저장하고 업데이트를 하기 때문입니다. 

 

2. 근사함수를 통한 가치 함수의 매개변수화

근사화

 

 위 1번의 큐함수큐 함수 테이블의 문제는 큐 함수를 매개변수로 근사함으로써 해결할 수 있습니다. 어떠한 데이터를 있는 그대로 저장하는 것이 가장 정확하겠지만 이러한 방식은 비효율적입니다. 비록 부정확하더라도 근사치 함수로 나타내는 것이 더 효율적입니다. 위 그림은 실제 데이터 점들을 근사화하여 빨간 함수로 나타낸 것입니다. 예를 들어 위의 근사 함수가 아래와 같은 함수라고 가정해보겠습니다. 이렇게 데이터를 표현하면 4개의 매개변수(a, b, c, d)만으로 기존의 데이터를 대체할 수 있습니다. 

근사함수(Function Approximator)

기존의 데이터를 매개변수를 통해 근사하는 함수를 근사함수(Fuction Approximator)라고 합니다. 이러한 근사 함수에는 여러 가지가 있지만 현재 성능이 제일 좋은 인공신경망(Artificial Neural Network)을 많이 사용합니다. 때문에 강화 학습에서 큐 함수를 테이블이 아닌 인공신경망으로 업데이트를 하면 위의 계산 복잡도와 차원의 저주 문제를 해결할 수 있습니다. 

 

그리드 월드 게임

 

3. 딥살사(Deep SARSA)

 딥살사는 살사 알고리즘에 인공신경망을 적용하여 큐 함수를 업데이트하는 알고리즘입니다. 그리드 월드 게임을 이용해 예를 들어보겠습니다. 먼저 그리드 월드 게임은 위와 같이 가운데 삼각형(장애물) 3개가 화살표 방향으로 계속 움직이고 월드의 끝에 가면 다시 튕겨서 돌아오는 것을 계속 반복합니다. 사각형은 에이전트를, 원은 도착점을 가리킵니다. 에이전트가 장애물(삼각형)을 만나면 -1, 도착하면 +1의 보상을 받습니다. 

 이 문제를 해결하기 위해서는 우선 MDP(Markov Decision Process)를 정의 해야 합니다. 이를 모르시는 분들은 아래 링크의 제 글을 참고하시기 바랍니다.

https://codingopera.tistory.com/22?category=1063355 

 

2. 강화학습 MDP(Markov Decision Process)

 안녕하세요 '코딩 오페라'블로그를 운영하고 있는 저는 'Master.M'입니다. 오늘 알아볼 내용은 'MDP(Markov Decision Process)'입니다.  저의 경우 2016년 알파고와 이세돌의 바둑 대결로 인해 인공지능에

codingopera.tistory.com

 

 MDP는 다음과 같습니다.

- 에이전트에 대한 도착지점의 상대 위치 x, y

- 도착지점의 라벨

- 에이전트에 대한 정애물의 상대 위치 x, y

- 장애물의 라벨

- 장애물의 속도

 

우리가 공의 속도와 위치 를 보고 피하거나 발로 차듯이 에이전트도 장애물에 대한 상대적인 위치 및 속도를 알아야 피할 수 있으므로 위와 같이 설정을 해줬습니다. 

 

Deep SARSA에서는 인공신경망을 이용해 큐값을 업데이트한다고 했습니다. 이때 경사 하강법을 사용합니다. 경사 하강법을 사용해 인공신경망을 업데이트하려면 오차 함수를 정의해야 합니다. 아래의 SARSA수식에서 정답과 예측의 역할을 하는 수식은 다음과 같습니다.

SARSA 알고리즘

 

정답의 역할
예측의 역할

 

Deep SARSA 오류함수

이를 MSE(Mean Squared Error)수식으로 나타내면 위와 같습니다. 자 그럼 다시 문제로 돌아와서 이를 코드로 어떻게 구현했는지 살펴보겠습니다. 

 

우선 위와 같이 필요한 라이브러리들을 import 해줍니다. 

 

 다음으로는 DeepSARSA라는 클래스를 만들어 인공신경망 모델을 만들어줍니다. 위의 경우 layer가 self.fc1, self.fc2, self.fc_out 이렇게 3개가 있습니다. 처음 두 layer의 출력은 각각 30개로 설정하였고 마지막은 output 이므로 action_size만큼 출력해줍니다. 즉 action을 골라줍니다. call 함수를 통해 state(=x)를 입력으로 넣고 출력으로 q를 받습니다. 

 

Epsilon Greedy 방법을 사용해줍니다. 

 

 마지막으로 train_model 함수를 이용해 모델을 훈련시켜줍니다. 여기서 중요한 함수들은 다음과 같습니다. 

- with tf.GradientTape() as tape: 경사 기록 장치(이 안에서 수행되는 연산의 경사가 기록)

- tape.watch(): gradient를 기록하기위해

- tf.reduce_sum: 모든 성분의 총합

- tf.reduce_mean: 모든 성분의 평균

- zip(A, B): A와 B원소들을 묶어준다

- tf.one_hot: 각 라벨에 해당하는 인덱스를 해당 위치만 1로 매핑하고, 나머지는 0으로 매핑한 벡터로 표현하는 방법

 

 

 전체 코드는 아래 제 깃허브 링크에 있으니 직접 코드를 실행하실 분들은 참고하시기 바랍니다. deep_sarsa_agent2.py가 에이전트 코드이고, environment.py가 게임 환경 코드입니다. 둘 다 다운로드하으셔야 합니다. 

https://github.com/CodingOpera/RL/tree/main/6-deep-sarsa

 

GitHub - CodingOpera/RL

Contribute to CodingOpera/RL development by creating an account on GitHub.

github.com

 

  지금 까지 저희는 '딥 살사(Deep SARSA)'에 대해 알아보았습니다. 도움이 되셨나요? 만약 되셨다면 구독 및 좋아요로 표현해 주시면 정말 많은 힘이 됩니다. 궁금한 사항 혹은 앞으로 다루어 주었으면 좋을 주제가 있으시면 댓글 남겨주시면 감사하겠습니다. 저는 '코딩 오페라'의 'Master.M'이었습니다. 감사합니다.

 

반응형