from utils import sqlConnect


def getSpeciesIntroData(ncbi_taxon_id, site, disease):
    """Collect the summary figures for the species/disease intro panel.

    Args:
        ncbi_taxon_id: NCBI taxonomy id of the species; must be int-convertible.
        site: body-site value matched against ``BodySite``.
        disease: MeSH uid of the disease, looked up in ``mesh_data``.

    Returns:
        dict with keys
          ``desc_disease``: the ``note`` column of the disease's ``mesh_data`` row;
          ``intro_disease_total``: ``mbodymap_samples`` rows for disease/site;
          ``intro_disease_processed``: those rows with a loaded-sample record;
          ``intro_disease_valid``: loaded samples with ``QCStatus=1``;
          ``intro_species_inDisease_valid``: ``loaded_uid_num_QC1`` from
            ``mbodymap_species_to_samples``, or 0 when no row/value exists.
        The four counts are formatted with thousands separators (``'1,234'``).
    """
    # NOTE(review): site/disease are still interpolated into the SQL text and
    # must come from a trusted source; switch to driver-level parameters if
    # sqlConnect supports them. int() keeps the unquoted taxon-id slot safe.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()
    sql_desc_disease = 'SELECT note FROM mesh_data WHERE uid="{}"'.format(disease)
    desc_disease = connect.queryOne(sql_desc_disease)['note']
    sql_intro_disease_total = 'SELECT COUNT(1) run_total FROM mbodymap_samples WHERE disease="{}" AND BodySite="{}"'.format(
        disease, site)
    intro_disease_total = connect.queryOne(sql_intro_disease_total)['run_total']
    sql_intro_disease_processed = 'SELECT COUNT(1) run_processed FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id'.format(
        disease, site)
    intro_disease_processed = connect.queryOne(sql_intro_disease_processed)['run_processed']
    sql_intro_disease_valid = 'SELECT COUNT(1) run_valid FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id AND t2.QCStatus=1'.format(
        disease, site)
    intro_disease_valid = connect.queryOne(sql_intro_disease_valid)['run_valid']
    sql_intro_species_inDisease_valid = 'SELECT loaded_uid_num_QC1 FROM mbodymap_species_to_samples WHERE disease="{}" AND BodySite="{}" AND ncbi_taxon_id={}'.format(
        disease, site, taxon_id)
    species_row = connect.queryOne(sql_intro_species_inDisease_valid)
    # Presumably a missing row means the species was not seen for this
    # disease/site; report 0 rather than crashing on None.
    intro_species_inDisease_valid = 0
    if species_row is not None and species_row['loaded_uid_num_QC1'] is not None:
        intro_species_inDisease_valid = species_row['loaded_uid_num_QC1']
    return {
        'desc_disease': desc_disease,
        'intro_disease_total': '{:,}'.format(intro_disease_total),
        'intro_disease_processed': '{:,}'.format(intro_disease_processed),
        'intro_disease_valid': '{:,}'.format(intro_disease_valid),
        'intro_species_inDisease_valid': '{:,}'.format(intro_species_inDisease_valid),
    }


def getSpeciesPlotData(ncbi_taxon_id, site, disease):
    """Build Plotly line and box traces for one species at one body site.

    Traces are produced for ``disease`` and, unless ``disease`` already is the
    healthy-control MeSH uid ``D006262``, for the healthy control first so the
    two groups can be compared. For each group:

    * the line trace has x = odd thresholds ``1, 3, 5, ...`` extended just past
      the species' ``relative_abundance_max`` for that group, and
      y = 100 * (abundance records >= threshold) / (QC-passed samples);
    * the box trace holds every relative-abundance record of the species in
      that group, once per record.

    Healthy-control traces are coloured green.

    Args:
        ncbi_taxon_id: NCBI taxonomy id of the species; must be int-convertible.
        site: body-site value matched against ``BodySite``.
        disease: MeSH uid of the disease of interest.

    Returns:
        ``{'linePlot': [...], 'boxPlot': [...]}``. Both lists are empty when
        fewer than 10 abundance records exist for ``disease`` at ``site``.
    """
    # NOTE(review): site/disease are interpolated into the SQL text and must
    # come from a trusted source; switch to driver-level parameters if
    # sqlConnect supports them. int() keeps the unquoted taxon-id slot safe.
    taxon_id = int(ncbi_taxon_id)
    healthy_uid = 'D006262'
    connect = sqlConnect.MySQLConnection()
    diseaseList = [disease] if disease == healthy_uid else [healthy_uid, disease]

    sql_runs_total = 'SELECT COUNT(1) total FROM mbodymap_relative_species_abundances t1, mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid AND t1.ncbi_taxon_id={}'.format(
        site, disease, taxon_id)
    runs_total = connect.queryOne(sql_runs_total)['total']
    # Fewer than 10 samples: not enough data to plot, return empty traces.
    if int(runs_total) < 10:
        return {'linePlot': [], 'boxPlot': []}

    linePlot_results = []
    boxPlot_results = []
    for item in diseaseList:
        sql_valid_total = 'SELECT COUNT(1) total FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id AND t2.QCStatus=1'.format(
            item, site)
        valid_total = int(connect.queryOne(sql_valid_total)['total'])
        sql_disease_name = 'SELECT term FROM mesh_data WHERE uid="{}"'.format(item)
        disease_name = connect.queryOne(sql_disease_name)['term']

        sql_max_ra = 'SELECT relative_abundance_max FROM mbodymap_species_to_samples WHERE ncbi_taxon_id={} AND BodySite="{}" AND disease="{}"'.format(
            taxon_id, site, item)
        # Query once; a missing row (or NULL value) counts as a maximum of 0.
        max_row = connect.queryOne(sql_max_ra)
        max_ra = 0
        if max_row is not None and max_row['relative_abundance_max'] is not None:
            max_ra = max_row['relative_abundance_max']

        sql_abundances = 'SELECT t1.relative_abundance ra FROM mbodymap_relative_species_abundances t1, mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid AND t1.ncbi_taxon_id={}'.format(
            site, item, taxon_id)
        abundances = [row['ra'] for row in (connect.query(sql_abundances) or [])]

        # x axis: odd thresholds from 1. int(max_ra) is bumped to the next odd
        # number above it and the range runs one step past that
        # (e.g. max 4 -> [1, 3, 5]; max 5 -> [1, 3, 5, 7]).
        upper = int(max_ra)
        upper += 1 if upper % 2 == 0 else 2
        roundX = list(range(1, upper + 2, 2))

        # y axis: percentage of samples reaching each threshold.
        # NOTE(review): the abundance query has no QCStatus filter while the
        # denominator counts QC-passed samples only - confirm this is intended.
        roundY = [
            (sum(1 for ra in abundances if ra >= rule) / valid_total * 100) if valid_total else 0.0
            for rule in roundX
        ]

        line_trace = {
            'type': 'scatter',
            'x': roundX,
            'y': roundY,
            'name': disease_name,
            'mode': 'lines+markers',
        }
        box_trace = {
            'type': 'box',
            'x': abundances,
            'name': disease_name,
            'boxpoints': 'outliers',
            'width': 0.1,
        }
        if item == healthy_uid:
            line_trace.update(marker={'color': 'green'}, line={'color': 'green'})
            box_trace.update(fillcolor='green', marker={'color': 'green'}, line={'color': 'green'})
        linePlot_results.append(line_trace)
        boxPlot_results.append(box_trace)

    return {
        'linePlot': linePlot_results,
        'boxPlot': boxPlot_results,
    }


def getGenusIntroData(ncbi_taxon_id, site, disease):
    """Collect the summary figures for the genus/disease intro panel.

    Args:
        ncbi_taxon_id: NCBI taxonomy id of the genus; must be int-convertible.
        site: body-site value matched against ``BodySite``.
        disease: MeSH uid of the disease, looked up in ``mesh_data``.

    Returns:
        dict with keys
          ``desc_disease``: the ``note`` column of the disease's ``mesh_data`` row;
          ``intro_disease_total``: ``mbodymap_samples`` rows for disease/site;
          ``intro_disease_processed``: those rows with a loaded-sample record;
          ``intro_disease_valid``: loaded samples with ``QCStatus=1``;
          ``intro_species_inDisease_valid``: ``loaded_uid_num_QC1`` from
            ``mbodymap_genus_to_samples``, or 0 when no row/value exists
            (the key name is shared with the species variant).
        The four counts are formatted with thousands separators (``'1,234'``).
    """
    # NOTE(review): site/disease are still interpolated into the SQL text and
    # must come from a trusted source; switch to driver-level parameters if
    # sqlConnect supports them. int() keeps the unquoted taxon-id slot safe.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()
    sql_desc_disease = 'SELECT note FROM mesh_data WHERE uid="{}"'.format(disease)
    desc_disease = connect.queryOne(sql_desc_disease)['note']
    sql_intro_disease_total = 'SELECT COUNT(1) run_total FROM mbodymap_samples WHERE disease="{}" AND BodySite="{}"'.format(
        disease, site)
    intro_disease_total = connect.queryOne(sql_intro_disease_total)['run_total']
    sql_intro_disease_processed = 'SELECT COUNT(1) run_processed FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id'.format(
        disease, site)
    intro_disease_processed = connect.queryOne(sql_intro_disease_processed)['run_processed']
    sql_intro_disease_valid = 'SELECT COUNT(1) run_valid FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id AND t2.QCStatus=1'.format(
        disease, site)
    intro_disease_valid = connect.queryOne(sql_intro_disease_valid)['run_valid']
    sql_intro_genus_inDisease_valid = 'SELECT loaded_uid_num_QC1 FROM mbodymap_genus_to_samples WHERE disease="{}" AND BodySite="{}" AND ncbi_taxon_id={}'.format(
        disease, site, taxon_id)
    genus_row = connect.queryOne(sql_intro_genus_inDisease_valid)
    # Presumably a missing row means the genus was not seen for this
    # disease/site; report 0 rather than crashing on None.
    intro_genus_inDisease_valid = 0
    if genus_row is not None and genus_row['loaded_uid_num_QC1'] is not None:
        intro_genus_inDisease_valid = genus_row['loaded_uid_num_QC1']
    return {
        'desc_disease': desc_disease,
        'intro_disease_total': '{:,}'.format(intro_disease_total),
        'intro_disease_processed': '{:,}'.format(intro_disease_processed),
        'intro_disease_valid': '{:,}'.format(intro_disease_valid),
        'intro_species_inDisease_valid': '{:,}'.format(intro_genus_inDisease_valid),
    }



def getGenusPlotData(ncbi_taxon_id, site, disease):
    """Build Plotly line and box traces for one genus at one body site.

    Traces are produced for ``disease`` and, unless ``disease`` already is the
    healthy-control MeSH uid ``D006262``, for the healthy control first so the
    two groups can be compared. For each group:

    * the line trace has x = odd thresholds ``1, 3, 5, ...`` extended just past
      the genus' ``relative_abundance_max`` (from ``mbodymap_genus_to_samples``)
      and y = 100 * (abundance records >= threshold) / (QC-passed samples);
    * the box trace holds every relative-abundance record for the taxon id in
      that group, once per record.

    Healthy-control traces are coloured green.

    Args:
        ncbi_taxon_id: NCBI taxonomy id of the genus; must be int-convertible.
        site: body-site value matched against ``BodySite``.
        disease: MeSH uid of the disease of interest.

    Returns:
        ``{'linePlot': [...], 'boxPlot': [...]}``. Both lists are empty when
        fewer than 10 abundance records exist for ``disease`` at ``site``.
    """
    # NOTE(review): site/disease are interpolated into the SQL text and must
    # come from a trusted source; switch to driver-level parameters if
    # sqlConnect supports them. int() keeps the unquoted taxon-id slot safe.
    # NOTE(review): abundances below are read from
    # mbodymap_relative_species_abundances even though this is the genus view;
    # confirm genus-level records live in that table.
    taxon_id = int(ncbi_taxon_id)
    healthy_uid = 'D006262'
    connect = sqlConnect.MySQLConnection()
    diseaseList = [disease] if disease == healthy_uid else [healthy_uid, disease]

    sql_runs_total = 'SELECT COUNT(1) total FROM mbodymap_relative_species_abundances t1, mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid AND t1.ncbi_taxon_id={}'.format(
        site, disease, taxon_id)
    runs_total = connect.queryOne(sql_runs_total)['total']
    # Fewer than 10 samples: not enough data to plot, return empty traces.
    if int(runs_total) < 10:
        return {'linePlot': [], 'boxPlot': []}

    linePlot_results = []
    boxPlot_results = []
    for item in diseaseList:
        sql_valid_total = 'SELECT COUNT(1) total FROM mbodymap_samples t1, mbodymap_loaded_samples t2 WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id AND t2.QCStatus=1'.format(
            item, site)
        valid_total = int(connect.queryOne(sql_valid_total)['total'])
        sql_disease_name = 'SELECT term FROM mesh_data WHERE uid="{}"'.format(item)
        disease_name = connect.queryOne(sql_disease_name)['term']

        sql_max_ra = 'SELECT relative_abundance_max FROM mbodymap_genus_to_samples WHERE ncbi_taxon_id={} AND BodySite="{}" AND disease="{}"'.format(
            taxon_id, site, item)
        # Query once; a missing row (or NULL value) counts as a maximum of 0.
        max_row = connect.queryOne(sql_max_ra)
        max_ra = 0
        if max_row is not None and max_row['relative_abundance_max'] is not None:
            max_ra = max_row['relative_abundance_max']

        sql_abundances = 'SELECT t1.relative_abundance ra FROM mbodymap_relative_species_abundances t1, mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid AND t1.ncbi_taxon_id={}'.format(
            site, item, taxon_id)
        abundances = [row['ra'] for row in (connect.query(sql_abundances) or [])]

        # x axis: odd thresholds from 1. int(max_ra) is bumped to the next odd
        # number above it and the range runs one step past that
        # (e.g. max 4 -> [1, 3, 5]; max 5 -> [1, 3, 5, 7]).
        upper = int(max_ra)
        upper += 1 if upper % 2 == 0 else 2
        roundX = list(range(1, upper + 2, 2))

        # y axis: percentage of samples reaching each threshold.
        # NOTE(review): the abundance query has no QCStatus filter while the
        # denominator counts QC-passed samples only - confirm this is intended.
        roundY = [
            (sum(1 for ra in abundances if ra >= rule) / valid_total * 100) if valid_total else 0.0
            for rule in roundX
        ]

        line_trace = {
            'type': 'scatter',
            'x': roundX,
            'y': roundY,
            'name': disease_name,
            'mode': 'lines+markers',
        }
        box_trace = {
            'type': 'box',
            'x': abundances,
            'name': disease_name,
            'boxpoints': 'outliers',
            'width': 0.1,
        }
        if item == healthy_uid:
            line_trace.update(marker={'color': 'green'}, line={'color': 'green'})
            box_trace.update(fillcolor='green', marker={'color': 'green'}, line={'color': 'green'})
        linePlot_results.append(line_trace)
        boxPlot_results.append(box_trace)

    return {
        'linePlot': linePlot_results,
        'boxPlot': boxPlot_results,
    }
