99 lines
3.6 KiB
Python
99 lines
3.6 KiB
Python
import pandas as pd
|
|
from sklearn.neighbors import NearestNeighbors
|
|
from sklearn.preprocessing import StandardScaler
|
|
import numpy as np
|
|
|
|
# NOTE(review): BaseConfig is not imported anywhere in this file; it must be
# supplied by the surrounding project — confirm, otherwise the first
# read_csv below fails with NameError.
train_data = pd.read_csv(BaseConfig.TRAIN_DATA)
test_data = pd.read_csv(BaseConfig.TEST_DATA)
seaphage_prob = pd.read_csv(BaseConfig.SEAPHAGE_PROB)

# Positive (label == 1) SEAPHAGES samples from both the train and test splits.
train_seaphage_true = train_data[(train_data['source'] == 'SEAPHAGES') & (train_data['label'] == 1)]
test_seaphage_true = test_data[(test_data['source'] == 'SEAPHAGES') & (test_data['label'] == 1)]
# ignore_index avoids duplicate index labels from the two splits; the merge
# below resets the index anyway, so downstream results are unaffected.
all_seaphage_true = pd.concat([train_seaphage_true, test_seaphage_true], ignore_index=True)

# Join key: the part of DNA_seqid before the first underscore.
all_seaphage_true['match_id'] = all_seaphage_true['DNA_seqid'].str.split('_').str[0]
seaphage_prob['match_id'] = seaphage_prob['DNA_seqid'].str.split('_').str[0]

n_true_before_merge = len(all_seaphage_true)
all_seaphage_true = all_seaphage_true.merge(
    seaphage_prob[['match_id', 'final_rank']],
    on='match_id',
    how='left',
)
# match_id is only a prefix of DNA_seqid, so several seaphage_prob rows may
# share one key; a left merge then emits one copy of the sample per match.
# Report that instead of letting duplicated samples pass silently.
if len(all_seaphage_true) > n_true_before_merge:
    print(
        f"WARNING: merge on match_id duplicated samples: "
        f"{n_true_before_merge} -> {len(all_seaphage_true)} rows"
    )
all_seaphage_true = all_seaphage_true.rename(columns={'final_rank': 'confidence'})
all_seaphage_true = all_seaphage_true.drop('match_id', axis=1)

# Samples without a matching final_rank have NaN confidence; NaN fails every
# comparison below, so these samples end up in none of the three tiers.
n_unranked = int(all_seaphage_true['confidence'].isna().sum())
if n_unranked:
    print(f"WARNING: {n_unranked} SEAPHAGES samples have no final_rank and are excluded")

# 1. Split samples into tiers by confidence (final_rank): <= 5 high, == 6 medium, > 6 low.
high_conf_samples = all_seaphage_true[all_seaphage_true['confidence'] <= 5]
medium_conf_samples = all_seaphage_true[all_seaphage_true['confidence'] == 6]
low_conf_samples = all_seaphage_true[all_seaphage_true['confidence'] > 6]
|
|
|
|
def sequence_to_onehot(sequence):
    """Encode a DNA sequence as a flat one-hot integer vector.

    Each character (case-insensitive) contributes four entries in A, T, C, G
    order; 'N' and any other unrecognised character contribute four zeros.

    Args:
        sequence: the nucleotide string to encode.

    Returns:
        A 1-D numpy array of length ``4 * len(sequence.upper())``.
    """
    slot_of = {'A': 0, 'T': 1, 'C': 2, 'G': 3}
    # Unknown bases map to -1, which never equals a slot index -> all zeros.
    flat = [
        int(slot_of.get(nucleotide, -1) == slot)
        for nucleotide in sequence.upper()
        for slot in range(4)
    ]
    return np.array(flat)
|
|
|
|
# 1. Prepare features: one-hot encode each sample's FS_period sequence.
def _encode_sequences(samples):
    """Return the one-hot encodings of ``samples['FS_period']`` stacked row-wise.

    ``samples`` must be non-empty, and every sequence must encode to the same
    length; otherwise ``np.vstack`` raises ``ValueError``.
    """
    return np.vstack([sequence_to_onehot(seq) for seq in samples['FS_period']])


def _select_closest(samples, scaler, knn, percentile):
    """Keep the rows of ``samples`` closest to the high-confidence reference set.

    Each row is one-hot encoded, standardized with ``scaler`` and scored by its
    mean distance to its nearest neighbours in ``knn``. Rows whose score is at
    or below the ``percentile``-th percentile of this group's scores are kept.

    An empty ``samples`` frame is returned as-is. Encoding, transforming and
    taking a percentile would each raise on zero rows.
    """
    if samples.empty:
        return samples
    features = scaler.transform(_encode_sequences(samples))
    distances, _ = knn.kneighbors(features)
    mean_distances = distances.mean(axis=1)
    threshold = np.percentile(mean_distances, percentile)
    return samples[mean_distances <= threshold]


if high_conf_samples.empty:
    raise ValueError(
        "No high-confidence (confidence <= 5) SEAPHAGES samples: "
        "cannot build the KNN reference set."
    )

# 2. Standardize features; the scaler is fitted on high-confidence samples only.
scaler = StandardScaler()
X_high = scaler.fit_transform(_encode_sequences(high_conf_samples))

# 3. KNN reference model over the high-confidence samples. kneighbors() rejects
# n_neighbors larger than the number of fitted samples, so clamp it.
n_neighbors = min(5, len(X_high))
knn = NearestNeighbors(n_neighbors=n_neighbors)
knn.fit(X_high)

# 4. Keep the closest 80% of medium-confidence and 50% of low-confidence samples.
selected_medium_samples = _select_closest(medium_conf_samples, scaler, knn, 80)
selected_low_samples = _select_closest(low_conf_samples, scaler, knn, 50)

# 5. Combine the selections, labelling each row with its confidence tier.
final_selected_samples = pd.concat([
    high_conf_samples.assign(confidence_level='high'),
    selected_medium_samples.assign(confidence_level='medium'),
    selected_low_samples.assign(confidence_level='low'),
])

# 6. Export.
final_selected_samples.to_csv('/mnt/lmpbe/guest01/PRF-V4/训练数据/seaphage_selected.csv', index=False)

# 7. Summary statistics.
print(f"高可信样本数量: {len(high_conf_samples)}")
print(f"选中的中度可信样本数量: {len(selected_medium_samples)}")
print(f"选中的低度可信样本数量: {len(selected_low_samples)}")
print(f"最终选择的总样本数量: {len(final_selected_samples)}")
|
|
|
|
|