Fixed training of three models
This commit is contained in:
19
README.md
19
README.md
@@ -1,5 +1,24 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
library_name: pytorch
|
||||||
|
tags:
|
||||||
|
- reinforcement-learning
|
||||||
|
- game-ai
|
||||||
|
- tic-tac-toe
|
||||||
|
- pytorch
|
||||||
|
---
|
||||||
|
|
||||||
# Ultimate Tic Tac Toe Deep Learning Bot
|
# Ultimate Tic Tac Toe Deep Learning Bot
|
||||||
|
|
||||||
|
Model for playing Ultimate Tic Tac Toe
|
||||||
|
|
||||||
|
## Available checkpoints
|
||||||
|
|
||||||
|
- `models/easy.pth`: easy-difficulty checkpoint.
- `models/medium.pth`: medium-difficulty checkpoint.
- `models/hard.pth`: hard-difficulty checkpoint.
|
||||||
|
|
||||||
|
## Usage

Run `python run.py --help` for help.
|
||||||
|
|
||||||
|
|||||||
38
hf_upload.py
Normal file
38
hf_upload.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi, create_repo
|
||||||
|
|
||||||
|
|
||||||
|
# Identifier of the Hugging Face Hub model repository that receives the files.
REPO_ID = "Syndria98/ultimate-tic-tac-toe-ai"

# Maps each local file path (key) to the path it is stored under inside the
# Hub repository (value). Source files keep their names; the three trained
# checkpoints are placed under the models/ directory of the repository.
FILES_TO_UPLOAD = {
    "model.py": "model.py",
    "game.py": "game.py",
    "run.py": "run.py",
    "README.md": "README.md",
    "easy.pth": "models/easy.pth",
    "medium.pth": "models/medium.pth",
    "hard.pth": "models/hard.pth",
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Upload the project sources and trained checkpoints to the Hub.

    Every local file listed in ``FILES_TO_UPLOAD`` is checked before any
    network call is made, so a missing checkpoint can never leave the
    repository created but only partially updated.

    Raises:
        FileNotFoundError: If one or more local files are missing. All
            missing paths are listed in the message.
    """
    # Validate up front: the original check ran inside the upload loop, which
    # meant earlier files were already pushed when a later one was absent.
    missing = [
        local_path
        for local_path in FILES_TO_UPLOAD
        if not Path(local_path).is_file()
    ]
    if missing:
        raise FileNotFoundError(f"Missing file: {', '.join(missing)}")

    # exist_ok=True makes the script safe to re-run against an existing repo.
    create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
    api = HfApi()

    for local_path, repo_path in FILES_TO_UPLOAD.items():
        api.upload_file(
            path_or_fileobj=str(Path(local_path)),
            path_in_repo=repo_path,
            repo_id=REPO_ID,
            repo_type="model",
        )
        print(f"Uploaded {local_path} -> {repo_path}")


if __name__ == "__main__":
    main()
|
||||||
23
train.sh
Executable file
23
train.sh
Executable file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/bin/bash
#
# Trains the hard, medium and easy checkpoints in sequence.
#
# Environment:
#   DEVICE             Device passed to every run via --device. Defaults to
#                      "cuda" when torch reports CUDA is available, else "cpu".
#   OUTPUT_CHECKPOINT  Checkpoint file written by the final run
#                      (default: easy.pth).
#
# Any extra command-line arguments are forwarded to the final run only.

set -euo pipefail

# Override with DEVICE=... ./train.sh if you want to force a specific device.
DEVICE="${DEVICE:-$(python -c 'import torch; print("cuda" if torch.cuda.is_available() else "cpu")')}"
OUTPUT_CHECKPOINT="${OUTPUT_CHECKPOINT:-easy.pth}"

# Heavier runs, executed before the configurable run below. They use the
# resolved $DEVICE (both as environment variable and as --device) instead of a
# hard-coded "cuda", so they honour the override and work on CPU-only hosts.
DEVICE="$DEVICE" python run.py train --checkpoint hard.pth --device "$DEVICE" --arena-compare-games 4 --arena-compare-simulations 4 --num-simulations 25
DEVICE="$DEVICE" python run.py train --checkpoint medium.pth --device "$DEVICE" --arena-compare-games 3 --arena-compare-simulations 4 --num-simulations 20 --epochs 5 --num-eps 15 --num-iters 30

# Configurable run; trailing "$@" lets callers append or override flags.
python run.py train \
  --checkpoint "$OUTPUT_CHECKPOINT" \
  --arena-compare-games 0 \
  --arena-compare-simulations 2 \
  --device "$DEVICE" \
  --num-simulations 20 \
  --epochs 3 \
  --num-eps 10 \
  --num-iters 15 \
  "$@"

echo "Training finished!"
|
||||||
Reference in New Issue
Block a user