import datetime

import pandas as pd
import torch
from tqdm import tqdm


def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    r"""
    Pad (and, if necessary, truncate) variable-length tensors to a fixed length.

    :param sequences: non-empty list of tensors, each of shape ``(L_i, *)``
        with identical trailing dimensions
    :param batch_first: if True the output is ``(batch, max_len, *)``;
        otherwise ``(max_len, batch, *)``
    :param padding_value: scalar written into positions past each sequence end
    :param max_len: fixed output length; sequences longer than this are
        truncated (the original raised a size-mismatch error in that case)
    :return: one padded tensor containing all sequences
    """
    trailing_dims = sequences[0].size()[1:]
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    # new_full keeps the dtype/device of the inputs without the deprecated
    # `.data.new(...)` pattern.
    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        # Clamp so over-long sequences are truncated instead of raising.
        length = min(tensor.size(0), max_len)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor[:length]
        else:
            out_tensor[:length, i, ...] = tensor[:length]

    return out_tensor


def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
    r"""
    Read a csv file in chunks, showing a tqdm progress bar.

    Columns ``Y`` and ``Z`` are parsed with :func:`pandas.eval`, so string
    representations of list-like values become Python objects.

    :param csv_path: path to the csv file (utf-8 encoded)
    :param nrows: maximum number of data rows to read (default: all)
    :param chunksize: number of rows read per chunk
    :return: a single DataFrame with all chunks concatenated
    """
    print("Loading csv...")

    # Count data rows first so the progress bar has a total; the context
    # manager closes the handle (the original leaked an open file object).
    with open(csv_path, 'r', encoding="utf8") as f:
        rows = sum(1 for _ in f) - 1  # minus the header

    if rows > nrows:
        # Cap the row count and read everything in a single chunk.
        rows = nrows
        chunksize = nrows

    chunk_list = []
    with tqdm(total=rows, desc='Rows read: ') as bar:
        for chunk in pd.read_csv(csv_path, converters={'Y': pd.eval, 'Z': pd.eval},
                                 chunksize=chunksize, nrows=rows):
            chunk_list.append(chunk)
            bar.update(len(chunk))

    # pd.concat accepts the list directly; no generator wrapper needed.
    return pd.concat(chunk_list, axis=0)


def format_time(elapsed):
    """
    Convert an elapsed duration in seconds into an hh:mm:ss string.

    The value is rounded to the nearest whole second before formatting.
    """
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))