@@ -1,13 +1,19 @@
+import collections
 from typing import Hashable, Optional

 import dask.array as da
 import numpy as np
 from numba import guvectorize
 from xarray import Dataset

+from sgkit import to_haplotype_calls
 from sgkit.stats.utils import assert_array_shape
 from sgkit.typing import ArrayLike
-from sgkit.utils import conditional_merge_datasets, define_variable_if_absent
+from sgkit.utils import (
+    conditional_merge_datasets,
+    define_variable_if_absent,
+    hash_columns,
+)
 from sgkit.window import has_windows, window_statistic

 from .. import variables
@@ -682,3 +688,190 @@ def pbs(
         {variables.stat_pbs: (["windows", "cohorts_0", "cohorts_1", "cohorts_2"], p)}
     )
     return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
+
+
+N_GARUD_H_STATS = 4  # H1, H12, H123, H2/H1
+
+
+def _Garud_h(k: ArrayLike) -> ArrayLike:
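+    # `k` holds one hash per haplotype; identical haplotypes share a hash,
+    # so the hash counts below stand in for haplotype counts (assuming no
+    # hash collisions)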
+    # find haplotype counts (sorted in descending order)
+    counts = sorted(collections.Counter(k.tolist()).values(), reverse=True)
+    counts = np.array(counts)
+
+    # find haplotype frequencies
+    n = k.shape[0]
+    f = counts / n
+
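+    # definitions from Garud et al. (2015), with p_i the sorted frequencies:
+    #   H1 = sum(p_i^2); H12 = (p1 + p2)^2 + sum_{i>2} p_i^2;
+    #   H123 = (p1 + p2 + p3)^2 + sum_{i>3} p_i^2; H2 = H1 - p1^2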
+    # compute H1
+    h1 = np.sum(f ** 2)
+
+    # compute H12
+    h12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)
+
+    # compute H123
+    h123 = np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)
+
+    # compute H2/H1
+    h2 = h1 - f[0] ** 2
+    h2_h1 = h2 / h1
+
+    return np.array([h1, h12, h123, h2_h1])
+
+
+def _Garud_h_cohorts(
+    ht: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int
+) -> ArrayLike:
+    k = hash_columns(ht)  # hash haplotypes
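+    # each column of ht is one haplotype, so `sample_cohort` is expected at
+    # haplotype resolution here (one cohort entry per column of ht)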
+    arr = np.empty((n_cohorts, N_GARUD_H_STATS))
+    for c in range(n_cohorts):
+        arr[c, :] = _Garud_h(k[sample_cohort == c])
+    return arr
+
+
+def Garud_h(
+    ds: Dataset,
+    *,
+    call_haplotype: Hashable = variables.call_haplotype,
+    merge: bool = True,
+) -> Dataset:
+    """Compute the H1, H12, H123 and H2/H1 statistics for detecting signatures
+    of soft sweeps, as defined in Garud et al. (2015).
+
+    By default, values of these statistics are calculated across all variants.
+    To compute values in windows, call :func:`window` before calling
+    this function.
+
+    Parameters
+    ----------
+    ds
+        Genotype call dataset.
+    call_haplotype
+        Call haplotype variable to use or calculate. Defined by
+        :data:`sgkit.variables.call_haplotype_spec`.
+        If the variable is not present in ``ds``, it will be computed
+        using :func:`to_haplotype_calls`.
+    merge
+        If True (the default), merge the input dataset and the computed
+        output variables into a single dataset, otherwise return only
+        the computed output variables.
+        See :ref:`dataset_merge` for more details.
+
+    Returns
+    -------
+    A dataset containing the following variables:
+
+    - `stat_Garud_h1` (windows, cohorts): Garud H1 statistic.
+      Defined by :data:`sgkit.variables.stat_Garud_h1_spec`.
+
+    - `stat_Garud_h12` (windows, cohorts): Garud H12 statistic.
+      Defined by :data:`sgkit.variables.stat_Garud_h12_spec`.
+
+    - `stat_Garud_h123` (windows, cohorts): Garud H123 statistic.
+      Defined by :data:`sgkit.variables.stat_Garud_h123_spec`.
+
+    - `stat_Garud_h2_h1` (windows, cohorts): Garud H2/H1 statistic.
+      Defined by :data:`sgkit.variables.stat_Garud_h2_h1_spec`.
+
+    Raises
+    ------
+    NotImplementedError
+        If the dataset is not diploid.
+
+    Warnings
+    --------
+    This function is currently only implemented for diploid datasets.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> import sgkit as sg
+    >>> import xarray as xr
+    >>> ds = sg.simulate_genotype_call_dataset(n_variant=5, n_sample=4)
+
+    >>> # Divide samples into two cohorts
+    >>> sample_cohort = np.repeat([0, 1], ds.dims["samples"] // 2)
+    >>> ds["sample_cohort"] = xr.DataArray(sample_cohort, dims="samples")
+
+    >>> # Divide into windows of size three (variants)
+    >>> ds = sg.window(ds, size=3, step=3)
+
+    >>> gh = sg.Garud_h(ds)
+    >>> gh["stat_Garud_h1"].values  # doctest: +NORMALIZE_WHITESPACE
+    array([[0.25 , 0.375],
+           [0.375, 0.375]])
+    >>> gh["stat_Garud_h12"].values  # doctest: +NORMALIZE_WHITESPACE
+    array([[0.375, 0.625],
+           [0.625, 0.625]])
+    >>> gh["stat_Garud_h123"].values  # doctest: +NORMALIZE_WHITESPACE
+    array([[0.625, 1.   ],
+           [1.   , 1.   ]])
+    >>> gh["stat_Garud_h2_h1"].values  # doctest: +NORMALIZE_WHITESPACE
+    array([[0.75      , 0.33333333],
+           [0.33333333, 0.33333333]])
+    """
+
+    if ds.dims["ploidy"] != 2:
+        raise NotImplementedError("Garud H only implemented for diploid genotypes")
+
+    ds = define_variable_if_absent(
+        ds, variables.call_haplotype, call_haplotype, to_haplotype_calls
+    )
+    variables.validate(ds, {call_haplotype: variables.call_haplotype_spec})
+
+    ht = ds[call_haplotype]
+
+    # convert sample cohorts to haplotype layout
+    sc = ds.sample_cohort.values
+    hsc = np.stack((sc, sc), axis=1).ravel()  # TODO: assumes diploid
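+    # e.g. sc = [0, 0, 1, 1] -> hsc = [0, 0, 0, 0, 1, 1, 1, 1]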
+    n_cohorts = sc.max() + 1  # 0-based indexing
+
+    if has_windows(ds):
+        gh = window_statistic(
+            ht,
+            lambda ht: _Garud_h_cohorts(ht, hsc, n_cohorts),
+            ds.window_start.values,
+            ds.window_stop.values,
+            dtype=np.float64,
+            # first chunks dimension is windows, computed in window_statistic
+            chunks=(-1, n_cohorts, N_GARUD_H_STATS),
+            new_axis=2,  # 2d -> 3d
+        )
+        n_windows = ds.window_start.shape[0]
+        assert_array_shape(gh, n_windows, n_cohorts, N_GARUD_H_STATS)
+        new_ds = Dataset(
+            {
+                variables.stat_Garud_h1: (
+                    ("windows", "cohorts"),
+                    gh[:, :, 0],
+                ),
+                variables.stat_Garud_h12: (
+                    ("windows", "cohorts"),
+                    gh[:, :, 1],
+                ),
+                variables.stat_Garud_h123: (
+                    ("windows", "cohorts"),
+                    gh[:, :, 2],
+                ),
+                variables.stat_Garud_h2_h1: (
+                    ("windows", "cohorts"),
+                    gh[:, :, 3],
+                ),
+            }
+        )
+    else:
+        # TODO: note this materializes all the data, so windowless should be discouraged/not supported
+        ht = ht.values
+
+        gh = _Garud_h_cohorts(ht, sample_cohort=hsc, n_cohorts=n_cohorts)
+        assert_array_shape(gh, n_cohorts, N_GARUD_H_STATS)
+
+        # 1D arrays need explicit dims in the Dataset constructor
+        new_ds = Dataset(
+            {
+                variables.stat_Garud_h1: ("cohorts", gh[:, 0]),
+                variables.stat_Garud_h12: ("cohorts", gh[:, 1]),
+                variables.stat_Garud_h123: ("cohorts", gh[:, 2]),
+                variables.stat_Garud_h2_h1: ("cohorts", gh[:, 3]),
+            }
+        )
+    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
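
For a quick hand check of the arithmetic, the four statistics can be reproduced with plain NumPy. In the sketch below, `garud_h_sketch` is a hypothetical standalone mirror of `_Garud_h` (not part of this diff); the toy cohort of four haplotypes with counts [2, 1, 1] is the configuration behind the 0.375 / 0.625 / 1.0 / 0.333… entries in the doctest above:

```python
import collections

import numpy as np


def garud_h_sketch(hashes):
    # haplotype counts in descending order, as in _Garud_h
    counts = np.array(sorted(collections.Counter(hashes).values(), reverse=True))
    f = counts / len(hashes)  # haplotype frequencies, largest first
    h1 = np.sum(f ** 2)
    h12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)
    h123 = np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)
    h2_h1 = (h1 - f[0] ** 2) / h1
    return h1, h12, h123, h2_h1


# four haplotypes, one seen twice: f = [0.5, 0.25, 0.25]
# H1    = 0.25 + 0.0625 + 0.0625     = 0.375
# H12   = (0.5 + 0.25)**2 + 0.25**2  = 0.625
# H123  = (0.5 + 0.25 + 0.25)**2     = 1.0
# H2/H1 = (0.375 - 0.25) / 0.375     = 0.333...
print(garud_h_sketch(["AAB", "AAB", "ABB", "BBB"]))
```

Note that the statistics depend only on the multiset of haplotype counts, which is why `_Garud_h` can operate on hashes rather than on the haplotypes themselves.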