Nharen commited on
Commit
dee25bb
·
verified ·
1 Parent(s): 700d4b0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -28
README.md CHANGED
@@ -21,24 +21,31 @@ model-index:
21
  value: 100
22
  ---
23
 
24
- ### Model Architectures
25
 
26
- #### CartPole-v1 (DQN)
27
- * **File:** `Cartpole.pth`
28
- * **Algorithm:** Deep Q-Network
29
- * **Input:** 4 discrete observations
30
- * **Output:** 2 discrete actions (Left/Right)
 
 
31
  * **Network Structure:**
32
- * **Input Layer:** 4 -> 128 (Linear)
33
- * **Activation:** ReLU
34
- * **Hidden Layer:** 128 -> 128 (Linear)
35
- * **Activation:** ReLU
36
- * **Output Layer:** 128 -> 2 (Linear)
37
 
 
38
 
39
- ### Test Code
 
 
40
 
41
- ```
 
 
 
 
42
  import torch
43
  import torch.nn as nn
44
  import gymnasium as gym
@@ -46,29 +53,32 @@ import numpy as np
46
  from huggingface_hub import hf_hub_download
47
 
48
  class MatchedNet(nn.Module):
49
- def __init__(self, n_observations=4, n_actions=2):
50
- super(MatchedNet, self).__init__()
51
- self.layer1 = nn.Linear(n_observations, 128)
52
  self.layer2 = nn.Linear(128, 128)
53
- self.layer3 = nn.Linear(128, n_actions)
54
 
55
  def forward(self, x):
56
  x = torch.relu(self.layer1(x))
57
  x = torch.relu(self.layer2(x))
58
  return self.layer3(x)
59
 
60
- def run_test():
61
- repo_id = "Nharen/Reward_Rush_DQN_Cart_Pole"
62
- path = hf_hub_download(repo_id=repo_id, filename="Cartpole.pth")
63
-
64
  model = MatchedNet()
65
  state_dict = torch.load(path, map_location='cpu', weights_only=True)
 
 
 
 
66
  model.load_state_dict(state_dict)
67
  model.eval()
68
 
69
  env = gym.make("CartPole-v1")
70
- rewards = []
71
-
72
  for _ in range(100):
73
  state, _ = env.reset()
74
  episode_reward = 0
@@ -76,17 +86,16 @@ def run_test():
76
  while not done:
77
  state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
78
  with torch.no_grad():
79
- action = model(state_t).max(1)[1].item()
80
  state, reward, terminated, truncated, _ = env.step(action)
81
  episode_reward += reward
82
  done = terminated or truncated
83
- rewards.append(episode_reward)
84
 
85
- print(f"Average Reward: {np.mean(rewards):.2f}")
86
  env.close()
87
 
88
  if __name__ == "__main__":
89
- run_test()
90
-
91
  ```
92
 
 
21
  value: 100
22
  ---
23
 
24
+ # Reward Rush: CartPole DQN
25
 
26
+ This repository contains the cleaned weights for a Deep Q-Network agent trained for the CartPole-v1 environment.
27
+
28
+ ## Model Architecture
29
+
30
+ The model utilizes a multi-layer perceptron structure designed for low-latency inference:
31
+ * **Input:** 4 state observations
32
+ * **Output:** 2 discrete actions
33
  * **Network Structure:**
34
+ * Linear(4, 128) -> ReLU
35
+ * Linear(128, 128) -> ReLU
36
+ * Linear(128, 2)
 
 
37
 
38
+ ## Common Implementation Mistakes to Avoid
39
 
40
+ 1. **Variable Naming:** The weights are mapped to specific names: `layer1`, `layer2`, and `layer3`. Using generic names like `fc1` or `nn.Sequential` will result in a loading error.
41
+ 2. **Missing Batch Dimension:** The model expects a batch dimension. Input states must be wrapped using `unsqueeze(0)` before inference.
42
+ 3. **Inference Logic:** The model outputs raw Q-values for both actions. Use `argmax(dim=1)` to select the correct action index for the environment.
43
 
44
+ ## Download and Test Code
45
+
46
+ This script downloads the weights from the Hugging Face repository, initializes the environment, and evaluates the agent over 100 test episodes.
47
+
48
+ ```python
49
  import torch
50
  import torch.nn as nn
51
  import gymnasium as gym
 
53
  from huggingface_hub import hf_hub_download
54
 
55
  class MatchedNet(nn.Module):
56
+ def __init__(self):
57
+ super().__init__()
58
+ self.layer1 = nn.Linear(4, 128)
59
  self.layer2 = nn.Linear(128, 128)
60
+ self.layer3 = nn.Linear(128, 2)
61
 
62
  def forward(self, x):
63
  x = torch.relu(self.layer1(x))
64
  x = torch.relu(self.layer2(x))
65
  return self.layer3(x)
66
 
67
+ def run_cartpole_test():
68
+ path = hf_hub_download(repo_id="Nharen/Reward_Rush_DQN_Cart_Pole", filename="Cartpole.pth")
69
+
 
70
  model = MatchedNet()
71
  state_dict = torch.load(path, map_location='cpu', weights_only=True)
72
+
73
+ if isinstance(state_dict, dict) and "policy_net_state_dict" in state_dict:
74
+ state_dict = state_dict["policy_net_state_dict"]
75
+
76
  model.load_state_dict(state_dict)
77
  model.eval()
78
 
79
  env = gym.make("CartPole-v1")
80
+ total_rewards = []
81
+
82
  for _ in range(100):
83
  state, _ = env.reset()
84
  episode_reward = 0
 
86
  while not done:
87
  state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
88
  with torch.no_grad():
89
+ action = model(state_t).argmax(dim=1).item()
90
  state, reward, terminated, truncated, _ = env.step(action)
91
  episode_reward += reward
92
  done = terminated or truncated
93
+ total_rewards.append(episode_reward)
94
 
95
+ print(f"Average Reward: {np.mean(total_rewards)}")
96
  env.close()
97
 
98
  if __name__ == "__main__":
99
+ run_cartpole_test()
 
100
  ```
101