-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyse_metagenomique.py
1841 lines (1520 loc) · 102 KB
/
analyse_metagenomique.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import sys
import ast
import random
import pandas as pd
import skbio
import seaborn as sns
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import scikit_posthocs as sp
import skbio.diversity.alpha as alpha
from skbio import TreeNode, DistanceMatrix
from skbio.tree import nj
from skbio.diversity import beta_diversity
from skbio.stats.ordination import pcoa
from skbio.stats.distance import permanova
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.manifold import MDS
from matplotlib import cm
from matplotlib.patches import Ellipse
from matplotlib_venn import venn2, venn3
from scipy.stats import f_oneway, shapiro, mannwhitneyu, kruskal, spearmanr, pearsonr
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from sklearn import preprocessing
from pydeseq2.preprocessing import deseq2_norm
def one_condition_struct(asv_info, condition, column_sample_name=None, column_condition_name=None):  # Function to create a structure containing different conditions with associated samples
    """Group the sample columns of ``asv_info`` by the condition each sample belongs to.

    A count-table column is attributed to a sample of the condition table when the
    column name starts with that sample name. When several sample names are prefixes
    of the same column (e.g. ``S1`` and ``S10`` for column ``S10``), the longest one
    wins, so a column is never also claimed by a shorter, unrelated sample name.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table whose column labels are matched against the sample names.
    condition : pandas.DataFrame or None
        Condition table with one row per sample.
    column_sample_name : str, optional
        Column of ``condition`` holding the sample names. Asked interactively when None.
    column_condition_name : str, optional
        Column of ``condition`` holding the condition names. Asked interactively when None.

    Returns
    -------
    dict or None
        ``{condition: [matching asv_info columns, ...]}`` in condition-table order, or
        ``None`` (after printing an error) when ``condition`` is ``None``.
    """
    # Check if a condition file is provided
    if condition is None:
        print("Error: Please provide a condition file.")
        return
    if column_sample_name is None or column_condition_name is None:
        print("Name of conditions available :")
        for name in condition.columns:
            print(name)
    if column_sample_name is None:
        column_sample_name = input("What is the name of the column with the sample names? : ")
    if column_condition_name is None:
        column_condition_name = input("What is the name of the column with the names of the conditions to which the samples belong? : ")
    # Sample names may be read as numbers; compare them as text
    sample_names = [str(sample) for sample in condition[column_sample_name]]
    # Attribute every count-table column to the longest sample name it starts with,
    # so that e.g. column "S10" is not also claimed by sample "S1".
    owner = {}
    for col in asv_info.columns:
        prefixes = [sample for sample in sample_names if str(col).startswith(sample)]
        if prefixes:
            owner[col] = max(prefixes, key=len)
    # Dictionary storing the conditions and their associated sample columns
    conditions = {}
    for sample, cond in zip(sample_names, condition[column_condition_name]):
        matching_samples = [col for col in asv_info.columns if owner.get(col) == sample]
        if not matching_samples:
            print(f"No matching samples found for {sample}. Skipping.")
            continue
        # setdefault keeps each condition key once, in first-seen order
        conditions.setdefault(cond, []).extend(matching_samples)
    return conditions
def two_conditions_struct(asv_info, condition, column_sample_name=None, column_condition_first_name=None, column_condition_second_name=None):  # Function to create a structure containing different conditions with associated samples
    """Group the sample columns of ``asv_info`` by a pair of conditions.

    Matching rules are the same as for a single condition: a column belongs to the
    longest sample name it starts with (so ``S10`` is not also claimed by ``S1``).

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table whose column labels are matched against the sample names.
    condition : pandas.DataFrame or None
        Condition table with one row per sample.
    column_sample_name, column_condition_first_name, column_condition_second_name : str, optional
        Columns of ``condition`` holding the sample names, the first condition and the
        second condition. Each one left as None is asked interactively.

    Returns
    -------
    tuple or None
        ``(conditions, column_condition_first_name, column_condition_second_name)``,
        where ``conditions`` maps ``(first condition, second condition)`` to the
        matching ``asv_info`` columns; ``None`` (after printing an error) when
        ``condition`` is ``None``.
    """
    # Check if a condition file is provided
    if condition is None:
        print("Error: Please provide a condition file.")
        return
    column_names = (column_sample_name, column_condition_first_name, column_condition_second_name)
    if any(name is None for name in column_names):
        print("Name of conditions available :")
        for name in condition.columns:
            print(name)
    if column_sample_name is None:
        column_sample_name = input("What is the name of the column with the sample names? : ")
    if column_condition_first_name is None:
        column_condition_first_name = input("What is the name of the column with the names of the first condition to which the samples belong? : ")
    if column_condition_second_name is None:
        column_condition_second_name = input("What is the name of the column with the names of the second condition to which the samples belong? : ")
    # Sample names may be read as numbers; compare them as text
    sample_names = [str(sample) for sample in condition[column_sample_name]]
    # Attribute every count-table column to the longest sample name it starts with
    owner = {}
    for col in asv_info.columns:
        prefixes = [sample for sample in sample_names if str(col).startswith(sample)]
        if prefixes:
            owner[col] = max(prefixes, key=len)
    # Dictionary storing the (first, second) condition pairs and their sample columns
    conditions = {}
    for sample, cond1, cond2 in zip(sample_names, condition[column_condition_first_name], condition[column_condition_second_name]):
        matching_samples = [col for col in asv_info.columns if owner.get(col) == sample]
        if not matching_samples:
            print(f"No matching samples found for {sample}. Skipping.")
            continue
        # setdefault keeps each condition pair once, in first-seen order
        conditions.setdefault((cond1, cond2), []).extend(matching_samples)
    print(conditions)
    return conditions, column_condition_first_name, column_condition_second_name
#####################
## Alpha Diversity ##
#####################
def alpha_diversity_one(asv_info):  # Function to calculate alpha diversity for a given sample
    """Interactively compute and print one alpha-diversity index for one sample.

    The user picks a sample column of ``asv_info`` (every column after the first, which
    holds the ASV identifiers) and an index among shannon, simpson, inverse_simpson,
    chao and richness. The value is printed, not returned. An unknown sample or index
    prints an error and ends the script via ``exit()``.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table: first column ``ASVNumber``, then one column of counts per sample.
    """
    sample_columns = list(asv_info.columns[1:])  # first column holds the ASV identifiers
    print("Available samples :")
    for column in sample_columns:
        print(column)
    sample_alpha = input("Which sample do you want alpha diversity ? : ")
    if sample_alpha not in sample_columns:
        print("Sample not found.")
        exit()
    counts = asv_info[sample_alpha]  # Retrieve counts for the selected sample
    alpha_index = input("Which alpha diversity index do you want to calculate ? (shannon / simpson / inverse_simpson / chao / richness): ")
    if alpha_index == 'shannon':
        alpha_diversity = skbio.diversity.alpha.shannon(counts)
        print("-- Alpha Diversity : Shannon index for sample ", sample_alpha, " : ", alpha_diversity)
    elif alpha_index == 'simpson':
        alpha_diversity = skbio.diversity.alpha.simpson(counts)
        print("-- Alpha Diversity : Simpson index for sample ", sample_alpha, " : ", alpha_diversity)
    elif alpha_index == 'inverse_simpson':
        # Inverse Simpson is 1 / sum(p_i^2) = 1 / dominance. skbio's simpson() returns
        # 1 - sum(p_i^2), so 1 / simpson() is not the inverse Simpson index.
        alpha_diversity = 1 / skbio.diversity.alpha.dominance(counts)
        print("-- Alpha Diversity : Inverse Simpson index for sample ", sample_alpha, " : ", alpha_diversity)
    elif alpha_index == 'chao':
        alpha_diversity = skbio.diversity.alpha.chao1(counts)
        print("-- Alpha Diversity : Chao index for sample ", sample_alpha, " : ", alpha_diversity)
    elif alpha_index == 'richness':
        asv_number = int((counts > 0).sum())  # number of ASVs present (count > 0) in the sample
        print("-- Alpha Diversity : Observed richness for sample ", sample_alpha, " : ", asv_number)
    else:
        print("Alpha diversity index not supported.")
        exit()
def alpha_diversity_all(asv_info):  # Function to calculate alpha diversity for all samples
    """Interactively compute one alpha-diversity index for every sample and print it.

    Every column after the first (which holds the ASV identifiers) is treated as a
    sample. An unsupported index prints an error and ends the script via ``exit()``.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table: first column ``ASVNumber``, then one column of counts per sample.

    Returns
    -------
    dict
        ``{sample column: index value}`` in column order (the same values are printed).
    """
    headers = {
        'shannon': "-- Shannon Alpha diversity for all samples --",
        'simpson': "-- Simpson Alpha diversity for all samples --",
        'inverse_simpson': "-- Inverse Simpson Alpha diversity for all samples --",
        'chao': "-- Chao Alpha diversity for all samples --",
        'richness': "-- Observed Richness Alpha diversity for all samples --",
    }
    alpha_index = input("Which alpha diversity index do you want to calculate ? (shannon / simpson / inverse_simpson / chao / richness): ")
    if alpha_index not in headers:
        print("Alpha diversity index not supported.")
        exit()
    print(headers[alpha_index])
    results = {}
    for column in asv_info.columns[1:]:  # Ignore the first column (ASVNumber) to have samples columns
        counts = asv_info[column]
        if alpha_index == 'shannon':
            results[column] = skbio.diversity.alpha.shannon(counts)
        elif alpha_index == 'simpson':
            results[column] = skbio.diversity.alpha.simpson(counts)
        elif alpha_index == 'inverse_simpson':
            # Inverse Simpson is 1 / sum(p_i^2) = 1 / dominance; skbio's simpson() is
            # 1 - sum(p_i^2), so 1 / simpson() is not the inverse Simpson index.
            results[column] = 1 / skbio.diversity.alpha.dominance(counts)
        elif alpha_index == 'chao':
            results[column] = skbio.diversity.alpha.chao1(counts)
        else:  # richness: number of ASVs with a count > 0 in the sample
            results[column] = int((counts > 0).sum())
    for column, value in results.items():
        print(column, " : ", value)
    return results
def statistical_test_alpha(asv_info, condition):  # Function to perform statistical tests on alpha diversity
    """Interactively test an alpha-diversity index for differences between conditions.

    Samples are grouped with ``one_condition_struct`` and the chosen index is computed
    per sample. A Shapiro-Wilk test on all values pooled together decides which tests
    are offered: ANOVA (optional Tukey post-hoc) or Pearson when its p-value exceeds
    the user threshold, otherwise Kruskal-Wallis (optional Dunn post-hoc),
    Mann-Whitney U or Spearman. Results are printed; nothing is returned.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table with an ``ASVNumber`` column and one column of counts per sample.
    condition : pandas.DataFrame or None
        Condition table forwarded to ``one_condition_struct``.
    """
    # Retrieve the structure containing sample data based on conditions
    conditions = one_condition_struct(asv_info, condition)
    if not conditions:
        return  # no condition file, or no sample matched a condition: nothing to test
    alpha_index = input("On which alpha diversity index would you like to perform a statistical test ? (shannon / simpson / inverse_simpson / chao / richness): ")
    if alpha_index not in ('shannon', 'simpson', 'inverse_simpson', 'chao', 'richness'):
        print("Alpha diversity index not supported.")
        exit()

    def index_value(counts):
        """Return the selected alpha-diversity index for one sample's counts."""
        if alpha_index == 'shannon':
            return skbio.diversity.alpha.shannon(counts)
        if alpha_index == 'simpson':
            return skbio.diversity.alpha.simpson(counts)
        if alpha_index == 'inverse_simpson':
            # Inverse Simpson is 1 / sum(p_i^2) = 1 / dominance; skbio's simpson() is
            # 1 - sum(p_i^2), so 1 / simpson() is not the inverse Simpson index.
            return 1 / skbio.diversity.alpha.dominance(counts)
        if alpha_index == 'chao':
            return skbio.diversity.alpha.chao1(counts)
        return int((counts > 0).sum())  # richness: number of ASVs present in the sample

    # Calculate alpha diversity for each condition
    alpha_results = {cond: [index_value(asv_info[sample]) for sample in samples]
                     for cond, samples in conditions.items()}
    print(alpha_results)

    def ask_two_groups():
        """Prompt for two condition names; return their value lists, or None if unknown."""
        print("Name of groups available :")
        for cond in conditions:
            print(cond)
        first_group = input("Name of first group you want to compare samples : ")
        second_group = input("Name of second group you want to compare samples : ")
        # input() returns text while condition labels may be e.g. integers from the file
        by_name = {str(cond): cond for cond in alpha_results}
        if first_group not in by_name or second_group not in by_name:
            print("Group not found.")
            return None
        return list(alpha_results[by_name[first_group]]), list(alpha_results[by_name[second_group]])

    alpha_values = [value for values in alpha_results.values() for value in values]
    if len(alpha_values) < 3:  # shapiro() needs at least 3 observations
        print("At least 3 alpha diversity values are required for the Shapiro-Wilk test.")
        return
    # NOTE(review): normality is checked on all groups pooled together, whereas ANOVA
    # assumes normality within each group; treat this as a rough screening step.
    s_statistic, p_value = shapiro(alpha_values)
    print("Shapiro-Wilk test result for normality (alpha diversity) : ")
    print("Statistic for the test : ", s_statistic)
    print("p-value : ", p_value)
    threshold_test = input("Threshold for test : ")
    if p_value > float(threshold_test):
        print(" Alpha diversity values comes from a normal distribution : ")
        test_para = input("Do you want to make ANOVA test or Pearson correlation test ? (anova/pearson) : ")
        if test_para == 'anova':
            alpha_values_array = [np.array(values) for values in alpha_results.values()]
            print(alpha_values_array)
            f_statistic, p_value = f_oneway(*alpha_values_array)  # One-way ANOVA test
            print("ANOVA test result between all groups : ")
            print("Statistic F : ", f_statistic)
            print("p-value : ", p_value)
            post_hoc = input("Do you want to make post-hoc test (Tukey) ? Yes/No : ")  # 95% test
            if post_hoc == 'Yes':
                alpha_values_flat = np.concatenate(alpha_values_array)  # Flattening the data
                # One condition label per value, in the same order as alpha_values_flat
                conditions_list = [cond for cond, data in alpha_results.items() for _ in data]
                tukey_result = pairwise_tukeyhsd(endog=alpha_values_flat, groups=conditions_list)
                print("Result of Tukey test : ")
                print(tukey_result)
        elif test_para == 'pearson':
            # Pearson correlation test
            groups = ask_two_groups()
            if groups is None:
                return
            first_array, second_array = groups
            # Both inputs must have the same length: crop to the shorter group.
            # NOTE(review): values are paired by position, which is only meaningful if
            # the two groups list matched samples in the same order — confirm.
            min_length = min(len(first_array), len(second_array))
            first_array, second_array = first_array[:min_length], second_array[:min_length]
            corr_statistic, p_value = pearsonr(first_array, second_array)
            print("Pearson correlation test result between these two groups : ")
            print("Correlation coefficient : ", corr_statistic)
            print("p-value : ", p_value)
        else:
            print("Statistical test not supported.")
    else:
        print(" Alpha diversity values does not come from a normal distribution : ")
        test_non_para = input("Do you want to make Kruskal-Wallis test or Mann-Whitney test or Spearman correlation test ? (kruskal / mann / spearman) : ")
        if test_non_para == 'kruskal':
            alpha_values_array = [np.array(values) for values in alpha_results.values()]
            h_statistic, p_value = kruskal(*alpha_values_array)  # Kruskal-Wallis test
            print("Kruskal-Wallis test result between all groups : ")
            print("Statistic H : ", h_statistic)
            print("p-value : ", p_value)
            post_hoc = input("Do you want to make post-hoc test (Dunn) ? Yes/No : ")
            if post_hoc == 'Yes':
                alpha_values_flat = np.concatenate(alpha_values_array)  # Flattening the data
                # One condition label per value, in the same order as alpha_values_flat (Dunn test)
                conditions_list = [cond for cond, data in alpha_results.items() for _ in data]
                df = pd.DataFrame({'Values': alpha_values_flat, 'Groups': conditions_list})
                dunn_result = sp.posthoc_dunn(df, val_col='Values', group_col='Groups')
                print("Result of Dunn test : ")
                print(dunn_result)
        elif test_non_para == 'mann':
            groups = ask_two_groups()
            if groups is None:
                return
            first_array, second_array = groups
            mwu_statistic, p_value = mannwhitneyu(first_array, second_array)  # Mann-Whitney U test
            print("Mann-Whitney U test result between these two groups : ")
            print("Statistic for the test : ", mwu_statistic)
            print("p-value : ", p_value)
        elif test_non_para == 'spearman':
            # Spearman correlation test
            groups = ask_two_groups()
            if groups is None:
                return
            first_array, second_array = groups
            # Both inputs must have the same length: crop to the shorter group.
            # NOTE(review): values are paired by position — confirm this pairing is meaningful.
            min_length = min(len(first_array), len(second_array))
            first_array, second_array = first_array[:min_length], second_array[:min_length]
            corr_statistic, p_value = spearmanr(first_array, second_array)
            print("Spearman correlation test result between these two groups : ")
            print("Correlation coefficient : ", corr_statistic)
            print("p-value : ", p_value)
        else:
            print("Statistical test not supported.")
def alpha_graph(asv_info):  # Function to create alpha diversity scatter plot
    """Interactively plot one alpha-diversity index for every sample as a scatter plot.

    Every column after the first (which holds the ASV identifiers) is a sample. The
    figure is optionally saved as ``scatter_plot_alpha_div.<pdf|svg|png>`` in the
    working directory and then shown. An unsupported index prints an error and ends
    the script via ``exit()``.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table: first column ``ASVNumber``, then one column of counts per sample.
    """
    alpha_index = input("Which alpha diversity index would you like ? (shannon / simpson / inverse_simpson / chao / richness) : ")
    if alpha_index not in ('shannon', 'simpson', 'inverse_simpson', 'chao', 'richness'):
        print("Alpha diversity index not supported.")
        exit()
    scores = {}  # sample column -> index value
    for column in asv_info.columns[1:]:
        counts = asv_info[column]  # Extract counts for the sample
        if alpha_index == 'shannon':
            scores[column] = skbio.diversity.alpha.shannon(counts)
        elif alpha_index == 'simpson':
            scores[column] = skbio.diversity.alpha.simpson(counts)
        elif alpha_index == 'inverse_simpson':
            # Inverse Simpson is 1 / sum(p_i^2) = 1 / dominance; skbio's simpson() is
            # 1 - sum(p_i^2), so 1 / simpson() is not the inverse Simpson index.
            scores[column] = 1 / skbio.diversity.alpha.dominance(counts)
        elif alpha_index == 'chao':
            scores[column] = skbio.diversity.alpha.chao1(counts)
        else:  # richness: number of ASVs with a count > 0 in the sample
            scores[column] = int((counts > 0).sum())
    # Create the figure: sample names on x, index values on y (plain lists, not dict views)
    plt.figure(figsize=(12, 10))
    plt.scatter(list(scores.keys()), list(scores.values()), color='mediumturquoise')
    plt.xlabel('Samples')
    plt.ylabel(f'{alpha_index.capitalize()} Alpha Diversity')
    plt.title(f'{alpha_index.capitalize()} Alpha Diversity according to samples')
    plt.xticks(rotation=90)
    plt.tight_layout()
    format_file = input("Which file format would you like to save the plot ? SVG / PDF / PNG ")
    if format_file == 'PDF':
        plt.savefig("scatter_plot_alpha_div.pdf", format='pdf', pad_inches=0.2)
    elif format_file == 'SVG':
        plt.savefig("scatter_plot_alpha_div.svg", format='svg', pad_inches=0.2)
    elif format_file == 'PNG':
        plt.savefig("scatter_plot_alpha_div.png", format='png', pad_inches=0.2)
    plt.show()
def alpha_graph_condition(asv_info, condition):  # Function to create alpha diversity boxplots grouped by conditions
    """Interactively draw one boxplot of an alpha-diversity index per condition.

    Samples are grouped with ``one_condition_struct``; each box gets its own colour
    and the individual sample values are overlaid as points. The figure is optionally
    saved as ``boxplot_alpha_div.<pdf|svg|png>`` in the working directory and then
    shown. An unsupported index prints an error and ends the script via ``exit()``.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table with an ``ASVNumber`` column and one column of counts per sample.
    condition : pandas.DataFrame or None
        Condition table forwarded to ``one_condition_struct``.
    """
    alpha_index = input("Which alpha diversity index would you like? (shannon / simpson / inverse_simpson / chao / richness): ")
    if alpha_index not in ('shannon', 'simpson', 'inverse_simpson', 'chao', 'richness'):
        print("Alpha diversity index not supported.")
        exit()
    # Retrieve the structure containing sample data based on conditions
    conditions = one_condition_struct(asv_info, condition)
    if not conditions:
        return  # no condition file, or no sample matched a condition: nothing to plot
    # Calculate alpha diversity for each condition
    alpha_results = {}
    for cond, samples in conditions.items():
        values = []
        for sample in samples:
            counts = asv_info[sample]
            if alpha_index == 'shannon':
                values.append(skbio.diversity.alpha.shannon(counts))
            elif alpha_index == 'simpson':
                values.append(skbio.diversity.alpha.simpson(counts))
            elif alpha_index == 'inverse_simpson':
                # Inverse Simpson is 1 / sum(p_i^2) = 1 / dominance; skbio's simpson() is
                # 1 - sum(p_i^2), so 1 / simpson() is not the inverse Simpson index.
                values.append(1 / skbio.diversity.alpha.dominance(counts))
            elif alpha_index == 'chao':
                values.append(skbio.diversity.alpha.chao1(counts))
            else:  # richness: number of ASVs with a count > 0 in the sample
                values.append(int((counts > 0).sum()))
        alpha_results[cond] = values
    labels = [str(cond) for cond in alpha_results]
    data = list(alpha_results.values())
    # Set the color palette: one colour per condition
    colors = sns.color_palette('rainbow', n_colors=len(data))
    # Boxplots for alpha diversity grouped by conditions. Tick labels are set once with
    # plt.xticks below, so boxplot's `labels` argument (renamed in Matplotlib 3.9) is not needed.
    plt.figure(figsize=(10, 6))
    bp = plt.boxplot(data, patch_artist=True, capprops={'linewidth': 0.0})
    plt.xlabel('Conditions')
    plt.ylabel(f'{alpha_index.capitalize()} Alpha Diversity')
    plt.title(f'{alpha_index.capitalize()} Alpha Diversity by Condition')
    # Customize box, median and whisker colours (boxplot returns two whiskers per box)
    for i, (box, median, color) in enumerate(zip(bp['boxes'], bp['medians'], colors)):
        box.set(facecolor=color + (0.2,), edgecolor=color)  # translucent fill, opaque edge
        median.set(color=color)
        bp['whiskers'][2 * i].set(color=color)  # Lower whisker
        bp['whiskers'][2 * i + 1].set(color=color)  # Upper whisker
    # Add one point per sample on top of its box; the fmt has no colour letter so it
    # does not conflict with the color= keyword
    for position, (values, color) in enumerate(zip(data, colors), start=1):
        plt.plot([position] * len(values), values, '.', alpha=0.9, markersize=9, color=color)
    plt.xticks(ticks=np.arange(1, len(data) + 1), labels=labels, rotation=90, ha='right')
    plt.tight_layout()
    format_file = input("Which file format would you like to save the plot ? SVG / PDF / PNG ")
    if format_file == 'PDF':
        plt.savefig("boxplot_alpha_div.pdf", format='pdf', pad_inches=0.2)
    elif format_file == 'SVG':
        plt.savefig("boxplot_alpha_div.svg", format='svg', pad_inches=0.2)
    elif format_file == 'PNG':
        plt.savefig("boxplot_alpha_div.png", format='png', pad_inches=0.2)
    plt.show()
####################
## Beta Diversity ##
####################
def beta_diversity_function(asv_info, tree_path='./phyloTree.newick'):  # Function to calculate beta diversity
    """Interactively compute a beta-diversity distance matrix between samples.

    Parameters
    ----------
    asv_info : pandas.DataFrame
        Count table with an ``ASVNumber`` column (ASV identifiers) and one column of
        counts per sample. It is not modified.
    tree_path : str, optional
        Newick tree used by the UniFrac metrics. Its node names are truncated at the
        first ``'|'`` to match the ASV identifiers. Defaults to ``./phyloTree.newick``.

    Returns
    -------
    skbio.DistanceMatrix
        Sample-by-sample distances whose ids are the sample column names. An
        unsupported metric prints an error and ends the script via ``exit()``.
    """
    # set_index builds a new frame, so the caller's asv_info keeps its own index
    counts_table = asv_info.set_index('ASVNumber').T.fillna(0)  # Transpose so that samples are rows
    beta_index = input("Which beta diversity index would you like ? : (braycurtis / jaccard / weighted_unifrac / unweighted_unifrac)")
    if beta_index in ['braycurtis', 'jaccard']:
        # ids passed explicitly so the matrix is labelled with the sample names
        distance_matrix = skbio.diversity.beta_diversity(beta_index, counts_table.values, ids=list(counts_table.index))
    elif beta_index in ['weighted_unifrac', 'unweighted_unifrac']:
        tree = TreeNode.read(tree_path)  # Load phylogenetic tree
        midpoint_rooted_tree = tree.root_at_midpoint()  # Set root to the tree
        # Replace each full node name with its short identifier (text before the first '|').
        # A preorder loop instead of a recursive helper: no Python recursion here, so deep
        # trees cannot hit the recursion limit in this function.
        for node in midpoint_rooted_tree.preorder(include_self=True):
            if node.name:
                node.name = node.name.split('|')[0]
        tree_tip_names = {tip.name for tip in midpoint_rooted_tree.tips()}  # Names of the tips (leaves)
        otu_ids = set(counts_table.columns)  # ASV names of the count table
        # Check matches between table ASVs and tree tips
        missing_in_tree = otu_ids - tree_tip_names
        missing_in_otu_table = tree_tip_names - otu_ids
        print(f"OTUs present in the table but missing in the tree : {missing_in_tree}")
        print(f"Tips present in the tree but missing in the table : {missing_in_otu_table}")
        # Calculate UniFrac distances
        distance_matrix = skbio.diversity.beta_diversity(
            metric=beta_index,
            counts=counts_table.values,
            ids=counts_table.index,
            otu_ids=counts_table.columns,
            tree=midpoint_rooted_tree,
        )
    else:
        print("Beta diversity index not supported.")
        exit()
    return distance_matrix
def beta_diversity_all(beta_diversity):  # Function to display beta diversity for all samples
    """Print a header line followed by the given beta-diversity distance matrix."""
    header = "-- Beta diversity distance matrix :"
    for line in (header, beta_diversity):
        print(line)
def statistical_test_beta(beta_diversity, asv_info, condition):  # Function to perform a statistical test (permanova) on beta diversity
    """Interactively run a PERMANOVA (global or pairwise) on a beta-diversity matrix.

    Each sample of the distance matrix is labelled with its condition from
    ``one_condition_struct``. The grouping handed to ``permanova`` follows the order
    of the matrix ids, because skbio pairs a list grouping with the ids by position.
    Samples without a condition are left out of the test. Results are printed.

    Parameters
    ----------
    beta_diversity : skbio.DistanceMatrix
        Distance matrix between samples (e.g. from ``beta_diversity_function``).
    asv_info : pandas.DataFrame
        Count table with an ``ASVNumber`` column and one column per sample.
    condition : pandas.DataFrame or None
        Condition table forwarded to ``one_condition_struct``.
    """
    # Retrieve the structure containing sample data based on conditions
    conditions = one_condition_struct(asv_info, condition)
    if not conditions:
        return  # no condition file, or no sample matched a condition
    sample_to_cond = {}
    for cond, samples in conditions.items():
        for sample in samples:
            sample_to_cond.setdefault(sample, cond)
    ids = list(beta_diversity.ids)
    if any(matrix_id in sample_to_cond for matrix_id in ids):
        id_to_sample = {matrix_id: matrix_id for matrix_id in ids}
    else:
        # The matrix ids are not sample names (matrix built without explicit ids): its
        # rows then follow the order of the sample columns of asv_info.
        sample_columns = [col for col in asv_info.columns if col != 'ASVNumber']
        id_to_sample = dict(zip(ids, sample_columns))
    # Keep the matrix order, dropping samples that belong to no condition
    kept_ids = [matrix_id for matrix_id in ids if id_to_sample.get(matrix_id) in sample_to_cond]
    id_to_cond = {matrix_id: sample_to_cond[id_to_sample[matrix_id]] for matrix_id in kept_ids}
    conditions_list = [id_to_cond[matrix_id] for matrix_id in kept_ids]
    if len(set(conditions_list)) < 2:
        print("At least two conditions with samples are required for a PERMANOVA test.")
        return
    distance_matrix = beta_diversity if len(kept_ids) == len(ids) else beta_diversity.filter(kept_ids)
    test_diff = input("Do you want to make permanova test on all conditions or Pairwise permanova tests ? (permanova / pairwise) : ")
    if test_diff == 'permanova':
        permanova_result = permanova(distance_matrix=distance_matrix, grouping=conditions_list)
        print("-- Permanova test results")
        print("Permanova statistic : ", permanova_result['test statistic'])
        print("p-value : ", permanova_result['p-value'])
        print("Number of permutations : ", permanova_result['number of permutations'])
    elif test_diff == 'pairwise':
        unique_cond = sorted(set(conditions_list))  # sorted: deterministic pair order
        pairwise_results = []
        for position, cond1 in enumerate(unique_cond):
            for cond2 in unique_cond[position + 1:]:
                # Samples which belong to cond1 or cond2, in matrix order
                pair_ids = [matrix_id for matrix_id in kept_ids if id_to_cond[matrix_id] in (cond1, cond2)]
                sub_matrix = distance_matrix.filter(pair_ids)
                sub_groups = [id_to_cond[matrix_id] for matrix_id in pair_ids]
                # Pairwise permanova test
                result = permanova(distance_matrix=sub_matrix, grouping=sub_groups)
                pairwise_results.append({
                    'Group1': cond1,
                    'Group2': cond2,
                    'Permanova_statistic': result['test statistic'],
                    'p-value': result['p-value'],
                    'Number_of_permutations': result['number of permutations'],
                })
        # Convert into dataframe
        pairwise_results_df = pd.DataFrame(pairwise_results)
        print("-- Pairwise Permanova test results")
        print(pairwise_results_df)
    else:
        print("Statistical test not supported.")
def beta_diversity_graph(beta_diversity, condition, asv_info): # Function to create beta diversity visualizations
beta_representation = input("Which beta diversity representation would you like ? (heatmap / NMDS / NMDS_grouped_by_condition / PCoA / PCoA_grouped_by_condition): ")
if beta_representation == 'heatmap': # Create a heatmap representation of beta diversity
plt.figure(figsize=(10, 8))
sns.heatmap(beta_diversity.to_data_frame(), cmap="rainbow_r", annot=True, fmt=".2f", linewidths=.5)
plt.title("Beta diversity heatmap")
plt.xlabel("Samples")
plt.ylabel("Samples")
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG ")
if format_file == 'PDF':
plt.savefig("heatmap_beta_diversity.pdf", format='pdf', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("heatmap_beta_diversity.svg", format='svg', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("heatmap_beta_diversity.png", format='png', pad_inches=0.2)
plt.show()
elif beta_representation == 'NMDS':
mds = MDS(metric=False, random_state=0) # MDS object which perform nonmetric MDS with reproducibility
beta_diversity_array = np.array(beta_diversity[:][:])
mds_results = mds.fit_transform(beta_diversity_array) # Performs MDS transformation
stress = mds.stress_
print(stress)
if 0.1 <= stress <= 0.2:
mds = MDS(n_components=3, metric=False, random_state=0) # MDS with 3 dimensions
mds_results = mds.fit_transform(beta_diversity_array) # Performs MDS transformation
stress = mds.stress_
print(stress)
if stress <= 0.2:
colors = sns.color_palette('rainbow', n_colors=len(beta_diversity_array)) # Generates a color for each sample
mds_results_df = pd.DataFrame(mds_results, columns=['Dimension 1', 'Dimension 2']) # Converts the MDS results into a DataFrame
sample_names = list(asv_info.columns[1:])
mds_results_df.index = sample_names # Set the index of the DataFrame to the sample names
# Create the figure with a legend
legend_handles = []
plt.figure(figsize=(12, 8))
for sample, color in zip(sample_names, colors): # Iterate over sample names and corresponding colors
# Retrieve the x and y coordinates for the current sample from a DataFrame
x = mds_results_df.loc[sample, 'Dimension 1']
y = mds_results_df.loc[sample, 'Dimension 2']
handle = plt.scatter(x, y, color=color, label=sample) # Plot the current sample with specified color and label
legend_handles.append(handle) # Append the handle to the legend_handles list for creating the legend later
plt.legend(handles=legend_handles, title='Samples', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.xlabel('NMDS 1')
plt.ylabel('NMDS 2')
plt.title('NMDS Plot')
plt.annotate(f'Stress: {stress:.4f}', xy=(0.83, -0.06), xycoords='axes fraction', fontsize=12,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5)) # Add an annotation for the stress value
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG ")
if format_file == 'PDF':
plt.savefig("NMDS.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("NMDS.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("NMDS.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
else :
print("The stress variable is ", stress, ". It's greater than 0.2. Perform a PCoA analysis instead.")
pcoa_results = pcoa(beta_diversity) # Perform PCoA
sample_names = list(asv_info.columns[1:]) # Extract sample name
pcoa_results.samples.index = sample_names # Set the IDs of PCoA results samples to the sample names
colors = sns.color_palette('rainbow', n_colors=len(sample_names)) # Generate a list of colors
legend_handles = []
plt.figure(figsize=(12, 8))
for sample, color in zip(sample_names, colors): # Iterate over sample names and corresponding colors
x = pcoa_results.samples.loc[sample, 'PC1']
y = pcoa_results.samples.loc[sample, 'PC2']
handle = plt.scatter(x, y, color=color, label=sample)
legend_handles.append(handle)
plt.legend(handles=legend_handles, title='Samples', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.title('PCoA Plot')
plt.xlabel('Axis 1')
plt.ylabel('Axis 2')
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("PCoA.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("PCoA.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("PCoA.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
elif beta_representation == 'NMDS_grouped_by_condition':
mds = MDS(metric=False, random_state=0) # MDS object which perform nonmetric MDS with reproducibility
beta_diversity_array = np.array(beta_diversity[:][:])
mds_results = mds.fit_transform(beta_diversity_array) # Performs MDS transformation
stress = mds.stress_
print(stress)
if 0.1 <= stress <= 0.2:
mds = MDS(n_components=3, metric=False, random_state=0) # NMDS calculate with 3 dimensions
mds_results = mds.fit_transform(beta_diversity_array) # Performs NMDS transformation
stress = mds.stress_
print(stress)
number_condition = input("How many conditions do you have? (one / two): ")
if number_condition == 'one':
if stress <= 0.2:
mds_results_df = pd.DataFrame(mds_results, columns=['Dimension 1', 'Dimension 2'])
# Replace row and column numbers with sample names
sample_names = list(asv_info.columns[1:])
mds_results_df.index = sample_names
# Retrieve the structure containing sample data based on conditions
conditions=one_condition_struct(asv_info, condition)
# Extract legend labels from conditions
legend_labels = []
for condition,samples_list in conditions.items():
legend_labels.append(condition)
# Make unique colors from seaborn palette rainbow for each condition
color = sns.color_palette('rainbow', n_colors=len(legend_labels)) # Generate as many colors as conditions
colors = {}
for i, condition in enumerate(legend_labels):
colors[condition] = color[i]
# Assign a color to each sample based on its condition
sample_colors = [colors[condition] for sample in mds_results_df.index for condition, samples_list in conditions.items() if sample in samples_list]
# Scatter plot with samples colored by condition
plt.figure(figsize=(12, 8))
for condition, color in colors.items():
indices = [i for i, sample in enumerate(mds_results_df.index) if sample in conditions[condition]]
plt.scatter(mds_results_df.iloc[indices, 0], mds_results_df.iloc[indices, 1], color=color, label=condition)
# Add ellipses for each condition
if len(indices) > 1: # Need at least 2 points to fit an ellipse
x_coords = mds_results_df.iloc[indices, 0]
y_coords = mds_results_df.iloc[indices, 1]
# Calculate covariance matrix and mean
cov = np.cov(x_coords, y_coords)
print(cov)
eigenvalues, eigenvectors = np.linalg.eig(cov)
eigenvalues = np.sqrt(eigenvalues)
# Calculate angle of ellipse
angle = np.rad2deg(np.arctan2(*eigenvectors[:, 0][::-1]))
# Create an ellipse
ellipse = Ellipse(xy=(np.mean(x_coords), np.mean(y_coords)),
width=eigenvalues[0] * 2, height=eigenvalues[1] * 2,
angle=angle, edgecolor=color, facecolor=color, alpha=0.2, lw=2)
plt.gca().add_patch(ellipse)
plt.legend(title='Conditions', loc='best', bbox_to_anchor=(1, 1))
plt.xlabel('NMDS 1')
plt.ylabel('NMDS 2')
plt.title('NMDS Plot')
plt.annotate(f'Stress: {stress:.4f}', xy=(0.83, -0.06), xycoords='axes fraction', fontsize=12,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5)) # Add an annotation for the stress value
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("NMDS_one_condition.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("NMDS_one_condition.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("NMDS_one_condition.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
else :
print("The stress variable is ", stress, ". It's greater than 0.2. Perform a PCoA analysis instead.")
beta_diversity_df=beta_diversity.to_data_frame()
# Replace row and column numbers with sample names
sample_names = list(asv_info.columns[1:])
beta_diversity_df.index = sample_names
beta_diversity_df.columns = sample_names
conditions=one_condition_struct(asv_info, condition) # Retrieve the structure containing sample data based on conditions
# Perform PCoA on beta diversity
pcoa_results = pcoa(beta_diversity)
pcoa_results.samples.index = sample_names
# Extract legend labels from conditions
legend_labels = []
for condition,samples_list in conditions.items():
legend_labels.append(condition)
# Make unique colors from seaborn palette rainbow for each condition
color = sns.color_palette('rainbow', n_colors=len(legend_labels)) # Generate as many colors as conditions
colors = {}
for i, condition in enumerate(legend_labels):
colors[condition] = color[i]
# Assign a color to each sample based on its condition
sample_colors = [colors[condition] for sample in pcoa_results.samples.index for condition, samples_list in conditions.items() if sample in samples_list]
# Scatter plot with samples colored by condition
plt.figure(figsize=(12, 8))
for condition, color in colors.items():
indices = [i for i, sample in enumerate(pcoa_results.samples.index) if sample in conditions[condition]] # Find indices of samples belonging to the current condition
plt.scatter(pcoa_results.samples.iloc[indices, 0], pcoa_results.samples.iloc[indices, 1], color=color, label=condition) # Scatter plot the samples belonging to the current condition using PC1 and PC2 coordinates
# Add ellipses for each condition
if len(indices) > 1: # Need at least 2 points to fit an ellipse
x_coords = pcoa_results.samples.iloc[indices, 0]
y_coords = pcoa_results.samples.iloc[indices, 1]
# Calculate covariance matrix and mean
cov = np.cov(x_coords, y_coords)
print(cov)
eigenvalues, eigenvectors = np.linalg.eig(cov)
eigenvalues = np.sqrt(eigenvalues)
# Calculate angle of ellipse
angle = np.rad2deg(np.arctan2(*eigenvectors[:, 0][::-1]))
# Create an ellipse
ellipse = Ellipse(xy=(np.mean(x_coords), np.mean(y_coords)),
width=eigenvalues[0] * 2, height=eigenvalues[1] * 2,
angle=angle, edgecolor=color, facecolor=color, alpha=0.2, lw=2)
plt.gca().add_patch(ellipse)
plt.legend(title='Conditions', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.title('PCoA Plot')
plt.xlabel('Axis 1')
plt.ylabel('Axis 2')
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("PCoA_one_condition.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("PCoA_one_condition.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("PCoA_one_condition.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
elif number_condition == 'two':
if stress <= 0.2:
mds_results_df = pd.DataFrame(mds_results, columns=['Dimension 1', 'Dimension 2'])
# Replace row and column numbers with sample names
sample_names = list(asv_info.columns[1:])
mds_results_df.index = sample_names
# Retrieve the structure containing sample data based on conditions
conditions, cond1_name, cond2_name=two_conditions_struct(asv_info, condition) # Get the list of samples associated with the condition
# Extract legend labels from conditions
legend_labels1 = sorted(set(condition1 for condition1, _ in conditions.keys()))
legend_labels2 = sorted(set(condition2 for _, condition2 in conditions.keys()))
# Make unique colors from seaborn palette rainbow for each value of the first condition
color_palette = sns.color_palette('rainbow', n_colors=len(legend_labels1))
colors = {label: color for label, color in zip(legend_labels1, color_palette)}
print(colors)
# Define markers for the second condition
markers = ['o', 's', '^', 'D', 'v', '*', '<', '>', 'p', 'h', 'H', '8', 'd', '1', '2', '3', '4', '8', '+', 'x', 'X', '|', '_'] # Different marker styles
marker_styles = {label: marker for label, marker in zip(legend_labels2, markers)}
print(marker_styles)
# Assign colors and markers to each sample based on its conditions
sample_colors = [colors[condition1] for sample in mds_results_df.index for (condition1, condition2), samples_list in conditions.items() if sample in samples_list]
sample_markers = [marker_styles[condition2] for sample in mds_results_df.index for (condition1, condition2), samples_list in conditions.items() if sample in samples_list]
# Scatter plot with samples colored by the first condition and shaped by the second condition
plt.figure(figsize=(12, 8))
for cond1 in legend_labels1:
for cond2 in legend_labels2:
indices = [i for i, sample in enumerate(mds_results_df.index) if sample in conditions.get((cond1, cond2), [])]
if indices:
plt.scatter(mds_results_df.iloc[indices, 0], mds_results_df.iloc[indices, 1], color=[colors[cond1]]*len(indices), marker=marker_styles[cond2], label=f'{cond1}, {cond2}')
# Create custom handles for the legend
handles_colors = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colors[cond], markersize=10, label=cond) for cond in legend_labels1]
handles_markers = [plt.Line2D([0], [0], marker=marker_styles[cond], color='w', markerfacecolor='grey', markersize=10, label=cond) for cond in legend_labels2]
# Add text labels for condition names
handles = ([plt.Line2D([], [], color='none', label=cond1_name)] + handles_colors + [plt.Line2D([], [], color='none', label=cond2_name)] + handles_markers)
labels = ([cond1_name] + legend_labels1 + [cond2_name] + legend_labels2)
plt.legend(handles=handles, labels=labels, title='Conditions : ', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.xlabel('NMDS 1')
plt.ylabel('NMDS 2')
plt.title('NMDS Plot')
plt.annotate(f'Stress: {stress:.4f}', xy=(0.83, -0.06), xycoords='axes fraction', fontsize=12,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5)) # Add an annotation for the stress value
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("NMDS_two_condition.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("NMDS_two_condition.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("NMDS_two_condition.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
else :
print("The stress variable is ", stress, ". It's greater than 0.2. Perform a PCoA analysis instead.")
beta_diversity_df=beta_diversity.to_data_frame()
# Replace row and column numbers with sample names
sample_names = list(asv_info.columns[1:])
beta_diversity_df.index = sample_names
beta_diversity_df.columns = sample_names
conditions, cond1_name, cond2_name =two_conditions_struct(asv_info, condition) # Get the list of samples associated with the condition
# Perform PCoA on beta diversity
pcoa_results = pcoa(beta_diversity)
pcoa_results.samples.index = sample_names
# Extract legend labels from conditions
legend_labels1 = sorted(set(condition1 for condition1, _ in conditions.keys()))
legend_labels2 = sorted(set(condition2 for _, condition2 in conditions.keys()))
# Make unique colors from seaborn palette rainbow for each value of the first condition
color_palette = sns.color_palette('rainbow', n_colors=len(legend_labels1))
colors = {label: color for label, color in zip(legend_labels1, color_palette)}
print(colors)
# Define markers for the second condition
markers = ['o', 's', '^', 'D', 'v', '*', '<', '>', 'p', 'h', 'H', '8', 'd', '1', '2', '3', '4', '8', '+', 'x', 'X', '|', '_'] # Different marker styles
marker_styles = {label: marker for label, marker in zip(legend_labels2, markers)}
print(marker_styles)
# Assign colors and markers to each sample based on its conditions
sample_colors = [colors[condition1] for sample in pcoa_results.samples.index for (condition1, condition2), samples_list in conditions.items() if sample in samples_list]
sample_markers = [marker_styles[condition2] for sample in pcoa_results.samples.index for (condition1, condition2), samples_list in conditions.items() if sample in samples_list]
# Scatter plot with samples colored by the first condition and shaped by the second condition
plt.figure(figsize=(12, 8))
for cond1 in legend_labels1:
for cond2 in legend_labels2:
indices = [i for i, sample in enumerate(pcoa_results.samples.index) if sample in conditions.get((cond1, cond2), [])]
if indices:
plt.scatter(pcoa_results.samples.iloc[indices, 0], pcoa_results.samples.iloc[indices, 1], color=[colors[cond1]]*len(indices), marker=marker_styles[cond2], label=f'{cond1}, {cond2}')
# Create custom handles for the legend
handles_colors = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colors[cond], markersize=10, label=cond) for cond in legend_labels1]
handles_markers = [plt.Line2D([0], [0], marker=marker_styles[cond], color='w', markerfacecolor='grey', markersize=10, label=cond) for cond in legend_labels2]
# Add text labels for condition names
handles = ([plt.Line2D([], [], color='none', label=cond1_name)] + handles_colors + [plt.Line2D([], [], color='none', label=cond2_name)] + handles_markers)
labels = ([cond1_name] + legend_labels1 + [cond2_name] + legend_labels2)
plt.legend(handles=handles, labels=labels, title='Conditions : ', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.title('PCoA Plot')
plt.xlabel('Axis 1')
plt.ylabel('Axis 2')
format_file = input("Which file format would you like to save the plot? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("PCoA_two_condition.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("PCoA_two_condition.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("PCoA_two_condition.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
elif beta_representation == 'PCoA': # Create a PCoA plot of beta diversity, colored by samples
pcoa_results = pcoa(beta_diversity) # Perform PCoA
sample_names = list(asv_info.columns[1:]) # Extract sample name
pcoa_results.samples.index = sample_names # Set the IDs of PCoA results samples to the sample names
colors = sns.color_palette('rainbow', n_colors=len(sample_names)) # Generate as many colors as samples
legend_handles = []
plt.figure(figsize=(12, 8))
for sample, color in zip(sample_names, colors): # Iterate over sample names and corresponding colors
x = pcoa_results.samples.loc[sample, 'PC1']
y = pcoa_results.samples.loc[sample, 'PC2']
handle = plt.scatter(x, y, color=color, label=sample)
legend_handles.append(handle)
plt.legend(handles=legend_handles, title='Samples', loc='best', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.title('PCoA Plot')
plt.xlabel('Axis 1')
plt.ylabel('Axis 2')
format_file = input("Which file format would you like to save the plot ? SVG or PDF or PNG")
if format_file == 'PDF':
plt.savefig("PCoA.pdf", format='pdf', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'SVG':
plt.savefig("PCoA.svg", format='svg', bbox_inches='tight', pad_inches=0.2)
elif format_file == 'PNG':
plt.savefig("PCoA.png", format='png', bbox_inches='tight', pad_inches=0.2)
plt.show()
elif beta_representation == 'PCoA_grouped_by_condition': # Create a PCoA plot of beta diversity grouped by conditions
beta_diversity_df=beta_diversity.to_data_frame()
# Replace row and column numbers with sample names
sample_names = list(asv_info.columns[1:])
beta_diversity_df.index = sample_names
beta_diversity_df.columns = sample_names
number_condition = input("How many conditions do you have? (one / two): ")
if number_condition == 'one':
conditions=one_condition_struct(asv_info, condition) # Retrieve the structure containing sample data based on conditions
# Perform PCoA on beta diversity
pcoa_results = pcoa(beta_diversity)
pcoa_results.samples.index = sample_names
# Extract legend labels from conditions
legend_labels = []
for condition,samples_list in conditions.items():
legend_labels.append(condition)
# Make unique colors from seaborn palette rainbow for each condition
color = sns.color_palette('rainbow', n_colors=len(legend_labels)) # Generate as many colors as conditions
colors = {}
for i, condition in enumerate(legend_labels):
colors[condition] = color[i]
# Assign a color to each sample based on its condition
sample_colors = [colors[condition] for sample in pcoa_results.samples.index for condition, samples_list in conditions.items() if sample in samples_list]
# Determine which axes to plot
x_axis = int(input("Which PCoA axis do you want to display on x-axis ( axis 1 = 0 / axis 2 = 1 / axis 3 = 2) ?"))
y_axis = int(input("Which PCoA axis do you want to display on y-axis ( axis 1 = 0 / axis 2 = 1 / axis 3 = 2) ?"))
# Scatter plot with samples colored by condition
plt.figure(figsize=(12, 8))
for condition, color in colors.items():
indices = [i for i, sample in enumerate(pcoa_results.samples.index) if sample in conditions[condition]] # Find indices of samples belonging to the current condition
plt.scatter(pcoa_results.samples.iloc[indices, x_axis], pcoa_results.samples.iloc[indices, y_axis], color=color, label=condition) # Scatter plot the samples belonging to the current condition using PC1 and PC2 coordinates
# Add ellipses for each condition
if len(indices) > 1: # Need at least 2 points to fit an ellipse
x_coords = pcoa_results.samples.iloc[indices, x_axis]
y_coords = pcoa_results.samples.iloc[indices, y_axis]
# Calculate covariance matrix and mean
cov = np.cov(x_coords, y_coords)
eigenvalues, eigenvectors = np.linalg.eig(cov)
eigenvalues = np.sqrt(eigenvalues)
# Calculate angle of ellipse
angle = np.rad2deg(np.arctan2(*eigenvectors[:, 0][::-1]))
# Create an ellipse
ellipse = Ellipse(xy=(np.mean(x_coords), np.mean(y_coords)),