from utils import sqlConnect
from service.taxonsPageService import getSpeciesPhenotype, getGenusPhenotype
import json


def _abundanceRuler(relative_abundance_max):
    """Return the odd relative-abundance thresholds 1, 3, 5, ... used as the line-plot x axis.

    The ruler ends at the smallest odd integer strictly greater than
    ``int(relative_abundance_max)``, so the largest observed value is covered.
    """
    top = int(relative_abundance_max)
    top += 1 if top % 2 == 0 else 2
    return list(range(1, top + 1, 2))


def _cumulativePrevalence(abundances, thresholds, total):
    """Return, for each threshold t, the percentage of ``total`` samples whose abundance is >= t.

    Sorting once and bisecting yields the same counts as scanning every value for
    every threshold, without rescanning the list per threshold.
    """
    from bisect import bisect_left

    ordered = sorted(abundances)
    size = len(ordered)
    # size - bisect_left(ordered, t) == number of values >= t.
    return [(size - bisect_left(ordered, t)) / total * 100 for t in thresholds]


def _taxonDetails(ncbi_taxon_id, site, samples_table, phenotype_fn):
    """Build the phenotype table, prevalence bar plot and abundance line plot for one taxon.

    Shared implementation of speciesDetails() and genusDetails().

    :param ncbi_taxon_id: NCBI taxonomy id; must be convertible to int.
    :param site: body-site label matched against the ``BodySite`` column.
    :param samples_table: per-rank summary table listing diseases for the taxon.
    :param phenotype_fn: callable ``(ncbi_taxon_id, site)`` producing the table data.
    :return: dict with keys ``speciesTableRes`` (phenotype_fn result), ``xPlot``
        (prevalence percentages as one-decimal strings, sorted descending),
        ``yPlot`` (matching disease names) and ``linePlot`` (one Plotly scatter
        trace per disease).
    :raises ValueError: if ``ncbi_taxon_id`` is not an integer.
    """
    # The id is interpolated into SQL unquoted; forcing an int blocks injection through it.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()
    table_res = phenotype_fn(ncbi_taxon_id, site)
    # NOTE(review): ``site`` is still interpolated into SQL text; switch to
    # parameterized queries if sqlConnect supports them.
    sql_phenotypes = ('SELECT disease, loaded_uid_num_QC1, relative_abundance_max FROM {} '
                      'WHERE ncbi_taxon_id={} AND BodySite="{}" AND loaded_uid_num_QC1>10').format(
        samples_table, taxon_id, site)

    bar_items = []
    line_traces = []
    for phenotype in connect.query(sql_phenotypes):
        disease = phenotype['disease']

        # Denominator for both plots: QC-passed samples with this disease at this site.
        sql_total = ('SELECT COUNT(1) total FROM mbodymap_samples t1, mbodymap_loaded_samples t2 '
                     'WHERE disease="{}" AND BodySite="{}" AND t1.run_id=t2.accession_id '
                     'AND t2.QCStatus=1').format(disease, site)
        total = int(connect.queryOne(sql_total)['total'])
        if total == 0:
            # Prevalence is undefined without samples; skip rather than divide by zero.
            continue

        name_row = connect.queryOne('SELECT term FROM mesh_data WHERE uid="{}"'.format(disease))
        # Fall back to the raw uid when mesh_data has no term for it.
        disease_name = disease if name_row is None else name_row['term']

        bar_items.append({
            'disease': disease_name,
            'prevalence': phenotype['loaded_uid_num_QC1'] / total * 100
        })

        # NOTE(review): this reads mbodymap_relative_species_abundances for both
        # species and genus, and applies no QCStatus filter although the
        # denominator does -- confirm both are intended.
        sql_abundance = ('SELECT t1.relative_abundance ra FROM mbodymap_relative_species_abundances t1, '
                         'mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" '
                         'AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid '
                         'AND t1.ncbi_taxon_id={}').format(site, disease, taxon_id)
        abundances = [row['ra'] for row in connect.query(sql_abundance)]

        thresholds = _abundanceRuler(phenotype['relative_abundance_max'])
        trace = {
            'type': 'scatter',
            'x': thresholds,
            'y': _cumulativePrevalence(abundances, thresholds, total),
            'name': disease_name,
            'mode': 'lines+markers'
        }
        if disease == 'D006262':
            # Highlight this phenotype in green (presumably MeSH "Health", the
            # healthy-control group -- confirm).
            trace['marker'] = {'color': 'green'}
            trace['line'] = {'color': 'green'}
        line_traces.append(trace)

    # Stable sort: equal prevalences keep query order.
    ranked = sorted(bar_items, key=lambda item: item['prevalence'], reverse=True)
    return {
        'speciesTableRes': table_res,
        'xPlot': [format(item['prevalence'], '.1f') for item in ranked],
        'yPlot': [item['disease'] for item in ranked],
        'linePlot': line_traces
    }


def speciesDetails(ncbi_taxon_id, site):
    """Return table, prevalence bar-plot and abundance line-plot data for a species at a body site.

    See _taxonDetails() for the structure of the returned dict.
    """
    return _taxonDetails(ncbi_taxon_id, site, 'mbodymap_species_to_samples', getSpeciesPhenotype)


def genusDetails(ncbi_taxon_id, site):
    """Return table, prevalence bar-plot and abundance line-plot data for a genus at a body site.

    See _taxonDetails() for the structure of the returned dict.
    """
    return _taxonDetails(ncbi_taxon_id, site, 'mbodymap_genus_to_samples', getGenusPhenotype)


def _taxonBoxPlot(ncbi_taxon_id, site, samples_table):
    """Build one Plotly box trace of relative abundance per disease for a taxon at a body site.

    Shared implementation of getGenusBoxPlot() and getSpeciesBoxPlot().

    :param ncbi_taxon_id: NCBI taxonomy id; must be convertible to int.
    :param site: body-site label matched against the ``BodySite`` column.
    :param samples_table: per-rank summary table listing diseases for the taxon.
    :return: ``{'boxPlot': [trace, ...]}``, one trace per disease having more
        than 10 QC-passed samples in ``samples_table``.
    :raises ValueError: if ``ncbi_taxon_id`` is not an integer.
    """
    # The id is interpolated into SQL unquoted; forcing an int blocks injection through it.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()
    # NOTE(review): ``site`` is still interpolated into SQL text; switch to
    # parameterized queries if sqlConnect supports them.
    sql_phenotypes = ('SELECT disease FROM {} WHERE ncbi_taxon_id={} AND BodySite="{}" '
                      'AND loaded_uid_num_QC1>10').format(samples_table, taxon_id, site)

    traces = []
    for phenotype in connect.query(sql_phenotypes):
        disease = phenotype['disease']
        name_row = connect.queryOne('SELECT term FROM mesh_data WHERE uid="{}"'.format(disease))
        # Fall back to the raw uid when mesh_data has no term for it.
        disease_name = disease if name_row is None else name_row['term']

        # NOTE(review): the species abundance table is read for genus too -- confirm.
        sql_abundance = ('SELECT t1.relative_abundance ra FROM mbodymap_relative_species_abundances t1, '
                         'mbodymap_samples t2, mbodymap_loaded_samples t3 WHERE t2.BodySite="{}" '
                         'AND t2.disease="{}" AND t2.run_id=t3.accession_id AND t1.loaded_uid=t3.uid '
                         'AND t1.ncbi_taxon_id={}').format(site, disease, taxon_id)
        values = [row['ra'] for row in connect.query(sql_abundance)]

        trace = {
            'type': 'box',
            'x': values,
            'name': disease_name,
            # Plotly enum values are case-sensitive; 'Outliers' was silently ignored.
            'boxpoints': 'outliers',
            'width': 0.5
        }
        if disease == 'D006262':
            # Highlight this phenotype in green (presumably MeSH "Health" -- confirm).
            trace['fillcolor'] = 'green'
            trace['marker'] = {'color': 'green'}
            trace['line'] = {'color': 'green'}
        traces.append(trace)

    return {
        'boxPlot': traces
    }


def getGenusBoxPlot(ncbi_taxon_id, site):
    """Return per-disease relative-abundance box-plot traces for a genus at a body site."""
    return _taxonBoxPlot(ncbi_taxon_id, site, 'mbodymap_genus_to_samples')


def getSpeciesBoxPlot(ncbi_taxon_id, site):
    """Return per-disease relative-abundance box-plot traces for a species at a body site."""
    return _taxonBoxPlot(ncbi_taxon_id, site, 'mbodymap_species_to_samples')


def _taxonIntro(ncbi_taxon_id, site, site_table, gmrepo_table):
    """Return summary info for a taxon at a body site plus cross-database link flags.

    Shared implementation of getSpeciesIntro() and getGenusIntro().

    :param ncbi_taxon_id: NCBI taxonomy id; must be convertible to int.
    :param site: body-site label matched against the ``BodySite`` column.
    :param site_table: per-rank taxon/site summary table.
    :param gmrepo_table: GMrepo table for this rank, used only for an existence check.
    :return: dict with ``ncbi_taxon_id``, ``name``, ``disease_num`` (int),
        ``loaded_uid_num`` (comma-grouped string) and ``have_health``. It also
        carries ``ifMvp`` + ``mvpData`` (parsed JSON), ``ifGmrepo`` and
        ``ifHmdad``, each set to True only when the taxon has a matching row.
    :raises ValueError: if ``ncbi_taxon_id`` is not an integer.
    """
    # The id is interpolated into SQL unquoted; forcing an int blocks injection through it.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()
    # NOTE(review): ``site`` is still interpolated into SQL text; switch to
    # parameterized queries if sqlConnect supports them.
    sql_intro = ('SELECT name, disease_num, loaded_uid_num, have_health FROM {} '
                 'WHERE ncbi_taxon_id={} AND BodySite="{}"').format(site_table, taxon_id, site)
    intro = connect.queryOne(sql_intro)

    # Links to this taxon in other databases (MVP / GMrepo / HMDAD). Each lookup
    # must filter on the requested taxon id value, not the column name.
    mvp = connect.queryOne(
        'SELECT attributes FROM cross_db_links WHERE ncbi_taxon_id={} LIMIT 1'.format(taxon_id))
    gmrepo = connect.queryOne(
        'SELECT 1 FROM {} WHERE ncbi_taxon_id={} LIMIT 1'.format(gmrepo_table, taxon_id))
    hmdad = connect.queryOne(
        'SELECT 1 FROM hmdad_taxon WHERE ncbi_taxon_id={} LIMIT 1'.format(taxon_id))

    result = {
        'ncbi_taxon_id': ncbi_taxon_id,
        'name': intro['name'],
        'disease_num': int(intro['disease_num']),
        'loaded_uid_num': '{:,}'.format(int(intro['loaded_uid_num'])),
        'have_health': intro['have_health']
    }
    if mvp is not None:
        result['ifMvp'] = True
        result['mvpData'] = json.loads(mvp['attributes'])
    if gmrepo is not None:
        result['ifGmrepo'] = True
    if hmdad is not None:
        result['ifHmdad'] = True
    return result


def getSpeciesIntro(ncbi_taxon_id, site):
    """Return species summary info at a body site with MVP/GMrepo/HMDAD link flags."""
    return _taxonIntro(ncbi_taxon_id, site, 'mbodymap_species_to_site', 'gmrepo_species')


def getGenusIntro(ncbi_taxon_id, site):
    """Return genus summary info at a body site with MVP/GMrepo/HMDAD link flags."""
    return _taxonIntro(ncbi_taxon_id, site, 'mbodymap_genus_to_site', 'gmrepo_genus')


def getMarkerSpecies(ncbi_taxon_id, site):
    """Return LEfSe marker information for a taxon at a body site.

    :param ncbi_taxon_id: NCBI taxonomy id; must be convertible to int.
    :param site: body-site label matched against the ``bodysite`` column.
    :return: one of three shapes.
        ``False`` if the taxon has no row in curated_lefse_analysis_results.
        ``{'hasMarkerData': False}`` if it has rows, but none at ``site``.
        Otherwise ``{'hasMarkerData': True, 'intro': ..., 'table': ...}``, where
        ``intro`` holds the comparison/project/bodysite/phenotype2 counts and
        ``table`` lists the matching rows ordered by ``lda`` descending.
    :raises ValueError: if ``ncbi_taxon_id`` is not an integer.
    """
    # The id is interpolated into SQL unquoted; forcing an int blocks injection through it.
    taxon_id = int(ncbi_taxon_id)
    connect = sqlConnect.MySQLConnection()

    # Is this taxon a marker in any comparison at all?
    sql_any = 'SELECT 1 FROM curated_lefse_analysis_results WHERE ncbi_taxon_id={} LIMIT 1'.format(taxon_id)
    if connect.queryOne(sql_any) is None:
        return False

    # NOTE(review): ``site`` is still interpolated into SQL text; switch to
    # parameterized queries if sqlConnect supports them.
    site_filter = 'ncbi_taxon_id={} AND bodysite="{}"'.format(taxon_id, site)

    # Does it have marker data at this body site?
    sql_at_site = 'SELECT 1 FROM curated_lefse_analysis_results WHERE {} LIMIT 1'.format(site_filter)
    if connect.queryOne(sql_at_site) is None:
        return {
            'hasMarkerData': False
        }

    sql_intro = ('SELECT count(1) comparisons, count(distinct project_id) projects, '
                 'count(distinct bodysite) bodysites, count(distinct phenotype2) unique_pheno '
                 'FROM curated_lefse_analysis_results WHERE {}').format(site_filter)
    sql_table = 'SELECT * FROM `curated_lefse_analysis_results` WHERE {} ORDER BY lda DESC'.format(site_filter)
    return {
        'hasMarkerData': True,
        'intro': connect.queryOne(sql_intro),
        'table': connect.query(sql_table)
    }
