import os
import argparse

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets
import torchvision.transforms as transforms


def main():
    # Configs
    parser = argparse.ArgumentParser()

    # Data configs
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--no_shuffle', action='store_true')

    # Model configs
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_features', type=int, default=32)

    # Trainer configs
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--output_dir', type=str, required=True)

    # Initialize args
    args = parser.parse_args()
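
    # Example invocation (the script name below is a placeholder):
    #   python train_mnist.py --data_dir ./data --output_dir ./output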

    # Parameters
    log_dir = os.path.join(args.output_dir, 'log')
    summary_writer = SummaryWriter(log_dir=log_dir)
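    # Logged scalars can be inspected with: tensorboard --logdir <output_dir>/log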
    model_path = os.path.join(args.output_dir, 'model.pt')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # Prepare dataset
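    # ToTensor converts PIL images to float32 tensors in [0, 1] with shape (C, H, W)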
    transform = transforms.ToTensor()
    MNIST_IMG_SIZE = 28
    MNIST_CHANNEL_SIZE = 1
    MNIST_NUM_CLASSES = 10
    mnist_train = datasets.MNIST(root=args.data_dir,
                                 train=True,
                                 transform=transform,
                                 download=True)

    mnist_test = datasets.MNIST(root=args.data_dir,
                                train=False,
                                transform=transform,
                                download=True)

    # Build dataloader
    train_loader = DataLoader(dataset=mnist_train,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=not args.no_shuffle)
    test_loader = DataLoader(dataset=mnist_test,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=False)
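    # num_workers > 0 loads batches in background worker processes; the test
    # loader is never shuffled so accuracy is computed over a fixed order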

    # Initialize model
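    # A single 3x3 conv with padding=1 (and the default stride of 1) keeps the
    # 28x28 spatial size, so Flatten yields 28 * 28 * num_features features per image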
    model = torch.nn.Sequential(
        torch.nn.Conv2d(MNIST_CHANNEL_SIZE, args.num_features, 3, padding=1, bias=True),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(MNIST_IMG_SIZE * MNIST_IMG_SIZE * args.num_features, MNIST_NUM_CLASSES, bias=False)
    ).to(device)

    # Define loss function & optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
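    # CrossEntropyLoss applies log-softmax internally, so the model outputs
    # raw logits and needs no explicit softmax layer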

    # Training loop
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        avg_loss = 0.0
        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            # Prediction
            pred = model(x)
            loss = criterion(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss += loss.item()
            if batch_idx % 100 == 0:
                print("[Epoch: {:>4}][Batch: {:>4}] step loss = {:>.9}".format(epoch, batch_idx, loss.item()))

        avg_loss /= len(train_loader)
        summary_writer.add_scalar('loss', avg_loss, global_step=epoch)

    print('Training Finished!')

    # Evaluate on the test set (no gradients needed at inference time)
    model.eval()
    correct = 0
    count = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)

            # argmax over the class dimension gives the predicted label
            predicted = pred.argmax(dim=1)

            correct += (predicted == y).sum().item()
            count += len(y)

    print('Testing Finished!')

    accuracy = correct / count
    print(f'Test accuracy: {accuracy:.4f}')
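
    # Optionally record the final test accuracy in TensorBoard as well
    summary_writer.add_scalar('accuracy/test', accuracy)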

    # Flush pending TensorBoard events, then save the model as TorchScript
    summary_writer.close()
    torch.jit.save(torch.jit.script(model), model_path)
    print('Model saved!')
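    # The saved TorchScript model can later be reloaded without this model
    # definition, e.g.: loaded = torch.jit.load(model_path, map_location='cpu')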


if __name__ == '__main__':
    main()