trttung1610 committed on
Commit
849fda6
1 Parent(s): 2952ae4

Create objects.py

Files changed (1)
  1. objects.py +290 -0
objects.py ADDED
@@ -0,0 +1,290 @@
import numpy as np
import pickle

BOARD_ROWS = 3
BOARD_COLS = 3


class State:
    def __init__(self, p1, p2):
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.p1 = p1
        self.p2 = p2
        self.isEnd = False
        self.boardHash = None
        # init: p1 plays first
        self.playerSymbol = 1

    # get unique hash of current board state
    def getHash(self):
        self.boardHash = str(self.board.reshape(BOARD_COLS * BOARD_ROWS))
        return self.boardHash

    def winner(self):
        # rows
        for i in range(BOARD_ROWS):
            if sum(self.board[i, :]) == 3:
                self.isEnd = True
                return 1
            if sum(self.board[i, :]) == -3:
                self.isEnd = True
                return -1
        # columns
        for i in range(BOARD_COLS):
            if sum(self.board[:, i]) == 3:
                self.isEnd = True
                return 1
            if sum(self.board[:, i]) == -3:
                self.isEnd = True
                return -1
        # diagonals
        diag_sum1 = sum([self.board[i, i] for i in range(BOARD_COLS)])
        diag_sum2 = sum([self.board[i, BOARD_COLS - i - 1] for i in range(BOARD_COLS)])
        diag_sum = max(abs(diag_sum1), abs(diag_sum2))
        if diag_sum == 3:
            self.isEnd = True
            if diag_sum1 == 3 or diag_sum2 == 3:
                return 1
            else:
                return -1

        # tie: no available positions left
        if len(self.availablePositions()) == 0:
            self.isEnd = True
            return 0
        # game not over yet
        self.isEnd = False
        return None

    def availablePositions(self):
        positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if self.board[i, j] == 0:
                    positions.append((i, j))  # must be a tuple, used to index the board
        return positions

    def updateState(self, position):
        self.board[position] = self.playerSymbol
        # switch to the other player
        self.playerSymbol = -1 if self.playerSymbol == 1 else 1

    # called only when the game ends
    def giveReward(self):
        result = self.winner()
        # backpropagate reward
        if result == 1:
            self.p1.feedReward(1)
            self.p2.feedReward(0)
        elif result == -1:
            self.p1.feedReward(0)
            self.p2.feedReward(1)
        else:
            # tie: the rewards are asymmetric, presumably to push p1,
            # who moves first, toward winning rather than drawing
            self.p1.feedReward(0.1)
            self.p2.feedReward(0.5)

    # board reset
    def reset(self):
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.boardHash = None
        self.isEnd = False
        self.playerSymbol = 1

    def playwithbot(self, rounds=100):
        for i in range(rounds):
            if i % 1000 == 0:
                print("Rounds {}".format(i))
            while not self.isEnd:
                # Player 1
                positions = self.availablePositions()
                p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
                # take action and update board state
                self.updateState(p1_action)
                board_hash = self.getHash()
                self.p1.addState(board_hash)
                # check whether the board has reached an end state

                win = self.winner()
                if win is not None:
                    # self.showBoard()
                    # ended with p1 either winning or drawing
                    self.giveReward()
                    self.p1.reset()
                    self.p2.reset()
                    self.reset()
                    break

                else:
                    # Player 2
                    positions = self.availablePositions()
                    p2_action = self.p2.chooseAction(positions, self.board, self.playerSymbol)
                    self.updateState(p2_action)
                    board_hash = self.getHash()
                    self.p2.addState(board_hash)

                    win = self.winner()
                    if win is not None:
                        # self.showBoard()
                        # ended with p2 either winning or drawing
                        self.giveReward()
                        self.p1.reset()
                        self.p2.reset()
                        self.reset()
                        break

    # play against a human
    def playwithhuman(self):
        while not self.isEnd:
            # Player 1 (computer)
            positions = self.availablePositions()
            p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
            # take action and update board state
            self.updateState(p1_action)
            self.showBoard()
            # check whether the board has reached an end state
            win = self.winner()
            if win is not None:
                if win == 1:
                    print(self.p1.name, "wins!")
                else:
                    print("tie!")
                self.reset()
                break

            else:
                # Player 2 (human)
                positions = self.availablePositions()
                p2_action = self.p2.chooseAction(positions)

                self.updateState(p2_action)
                self.showBoard()
                win = self.winner()
                if win is not None:
                    if win == -1:
                        print(self.p2.name, "wins!")
                    else:
                        print("tie!")
                    self.reset()
                    break

    def showBoard(self):
        # p1: x, p2: o
        for i in range(0, BOARD_ROWS):
            print('-------------')
            out = '| '
            for j in range(0, BOARD_COLS):
                if self.board[i, j] == 1:
                    token = 'x'
                elif self.board[i, j] == -1:
                    token = 'o'
                else:
                    token = ' '
                out += token + ' | '
            print(out)
        print('-------------')
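    # Illustrative example (not in the original file): with p1 at (0, 0)
    # and p2 at (1, 1), showBoard() prints:
    #   -------------
    #   | x |   |   |
    #   -------------
    #   |   | o |   |
    #   -------------
    #   |   |   |   |
    #   -------------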


class Player:
    def __init__(self, name, exp_rate=0.3):
        self.name = name
        self.states = []  # record all positions taken
        self.lr = 0.2
        self.exp_rate = exp_rate
        self.decay_gamma = 0.9
        self.states_value = {}  # state -> value
        self.loadPolicy('policy_' + str(self.name))  # load a pre-trained policy if one exists

    def getHash(self, board):
        boardHash = str(board.reshape(BOARD_COLS * BOARD_ROWS))
        return boardHash

    def chooseAction(self, positions, current_board, symbol):
        # epsilon-greedy: explore with probability exp_rate, otherwise pick
        # the move whose resulting board state has the highest learned value
        if np.random.uniform(0, 1) <= self.exp_rate:
            # take a random action
            idx = np.random.choice(len(positions))
            action = positions[idx]
        else:
            value_max = -999
            for p in positions:
                next_board = current_board.copy()
                next_board[p] = symbol
                next_boardHash = self.getHash(next_board)
                value = 0 if self.states_value.get(next_boardHash) is None else self.states_value.get(next_boardHash)
                # print("value", value)
                if value >= value_max:
                    value_max = value
                    action = p
        # print("{} takes action {}".format(self.name, action))
        return action

    # append a hash state
    def addState(self, state):
        self.states.append(state)

    # at the end of the game, backpropagate and update state values
    def feedReward(self, reward):
        for st in reversed(self.states):
            if self.states_value.get(st) is None:
                self.states_value[st] = 0
            self.states_value[st] += self.lr * (self.decay_gamma * reward - self.states_value[st])
            reward = self.states_value[st]

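    # Worked example (illustrative; not part of the original file): with the
    # defaults lr=0.2 and decay_gamma=0.9, an episode whose recorded states
    # are ["s0", "s1"] and whose terminal reward is 1 updates backwards:
    #   V(s1) = 0 + 0.2 * (0.9 * 1    - 0) = 0.18
    #   V(s0) = 0 + 0.2 * (0.9 * 0.18 - 0) = 0.0324
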
    def reset(self):
        self.states = []

    def savePolicy(self):
        with open('policy_' + str(self.name), 'wb') as fw:
            pickle.dump(self.states_value, fw)

    def loadPolicy(self, file):
        # load a pickled policy if the file exists; otherwise keep the empty
        # value table so a fresh agent can still be trained from scratch
        try:
            with open(file, 'rb') as fr:
                self.states_value = pickle.load(fr)
        except FileNotFoundError:
            pass


class HumanPlayer:
    def __init__(self, name):
        self.name = name

    def chooseAction(self, positions):
        while True:
            try:
                row = int(input("Input your action row:"))
                col = int(input("Input your action col:"))
            except ValueError:
                print("Please enter integer coordinates.")
                continue
            action = (row, col)
            if action in positions:
                return action
            print("That square is unavailable; try again.")

    # append a hash state
    def addState(self, state):
        pass

    # at the end of the game, backpropagate and update state values
    def feedReward(self, reward):
        pass

    def reset(self):
        pass


if __name__ == "__main__":
    # training
    p1 = Player("p1")
    p2 = Player("p2")

    st = State(p1, p2)
    print("training...")
    st.playwithbot(200000)

    p1.savePolicy()
    p2.savePolicy()

    # # play with human
    # p1 = Player("computer", exp_rate=0)
    # p1.loadPolicy("policy_p1")

    # p2 = HumanPlayer("human")

    # st = State(p1, p2)
    # st.playwithhuman()
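
For reference, a minimal sketch of the play-with-human mode that the commented-out lines above describe, assuming training has already written a policy_p1 file to the working directory:

p1 = Player("computer", exp_rate=0)  # exp_rate=0 disables exploration at play time
p1.loadPolicy("policy_p1")
p2 = HumanPlayer("human")

st = State(p1, p2)
st.playwithhuman()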