Skip to content

Commit

Permalink
Merge pull request #340 from youguohui/patch-1
Browse files Browse the repository at this point in the history
Update dataset_transform.py
  • Loading branch information
yangapku authored Aug 6, 2024
2 parents 8cbef86 + 4e117f1 commit 85f3fa3
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions dataset_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@
from io import BytesIO
from sklearn.model_selection import train_test_split


# 读取csv文件
'''
original_dataset原始数据的路径文件夹,需修改为实际的路径
'''

#训练和验证集文本数据的文件
data1 = pd.read_csv('original_dataset/data1/ImageWordData.csv')
#训练和验证集图像数据的目录
data1_images_folder='original_dataset/data1/ImageData'

# 先将文本及对应图像id划分训练集和验证集
train_data, val_data = train_test_split(data1, test_size=0.2, random_state=42)

# 创建函数来处理数据集,使文本关联到其对应图像id的图像
def process_train_valid(data, img_file, txt_file):
def process_train_valid(data, images_folder, img_file, txt_file):
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
for index, row in data.iterrows():
# 图片内容需要被编码为base64格式
img_path = os.path.join('original_dataset/data1/ImageData', row['image_id'])
img_path = os.path.join(images_folder, row['image_id'])
with open(img_path, 'rb') as f_img_file:
img = Image.open(f_img_file)
img_buffer = BytesIO()
Expand All @@ -36,21 +38,24 @@ def process_train_valid(data, img_file, txt_file):
f_txt.write(json.dumps(text_data) + '\n')

# 处理训练集和验证集
process_train_valid(train_data, 'Chinese-CLIP/datasets/DatasetName/train_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/train_texts.jsonl')
process_train_valid(val_data, 'Chinese-CLIP/datasets/DatasetName/valid_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/valid_texts.jsonl')
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
process_train_valid(train_data, data1_images_folder, 'Chinese-CLIP/datasets/DatasetName/train_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/train_texts.jsonl')
process_train_valid(val_data, data1_images_folder, 'Chinese-CLIP/datasets/DatasetName/valid_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/valid_texts.jsonl')



#制作从文本到图像(Text_to_Image)检索时的,测试集。data2为Text_to_Image测试数据文件夹名
# 制作从文本到图像(Text_to_Image)检索时的,测试集。data2为Text_to_Image测试数据文件夹名
image_data2 = pd.read_csv('original_dataset/data2/image_data.csv')
word_test2 = pd.read_csv('original_dataset/data2/word_test.csv')
# 原始图像测试集目录
data2_images_folder='original_dataset/data2/ImageData'

# 处理Text_to_Image测试集
def process_text_to_image(image_data, word_test, img_file, txt_file):
def process_text_to_image(image_data, images_folder, word_test, img_file, txt_file):
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
for index, row in image_data.iterrows():
# 图片内容需要被编码为base64格式
img_path = os.path.join('../dataset/data2/ImageData', row['image_id'])
img_path = os.path.join(images_folder, row['image_id'])
with open(img_path, 'rb') as f_img_file:
img = Image.open(f_img_file)
img_buffer = BytesIO()
Expand All @@ -65,20 +70,23 @@ def process_text_to_image(image_data, word_test, img_file, txt_file):
text_data = {"text_id": row["text_id"], "text": row["caption"], "image_ids": []}
f_txt.write(json.dumps(text_data) + '\n')

process_text_to_image(image_data2, word_test2, 'Chinese-CLIP/datasets/DatasetName/test2_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test2_texts.jsonl')
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
process_text_to_image(image_data2, data2_images_folder, word_test2, 'Chinese-CLIP/datasets/DatasetName/test2_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test2_texts.jsonl')



#制作从图像到文本(Image_to_Text)检索时的,测试集。data3为Image_to_Text测试数据文件夹名
# 制作从图像到文本(Image_to_Text)检索时的,测试集。data3为Image_to_Text测试数据文件夹名
image_test3 = pd.read_csv('original_dataset/data3/image_test.csv')
word_data3 = pd.read_csv('original_dataset/data3/word_data.csv')
# 原始图像测试集目录
data3_images_folder='original_dataset/data3/ImageData'

# 处理Image_to_Text测试集
def process_image_to_text(image_data, word_test, img_file, txt_file):
def process_image_to_text(image_data, images_folder, word_test, img_file, txt_file):
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
for index, row in image_data.iterrows():
# 图片内容需要被编码为base64格式
img_path = os.path.join('../dataset/data3/ImageData', row['image_id'])
img_path = os.path.join(images_folder, row['image_id'])
with open(img_path, 'rb') as f_img_file:
img = Image.open(f_img_file)
img_buffer = BytesIO()
Expand All @@ -93,7 +101,8 @@ def process_image_to_text(image_data, word_test, img_file, txt_file):
text_data = {"text_id": row["text_id"], "text": row["caption"], "image_ids": []}
f_txt.write(json.dumps(text_data) + '\n')

process_image_to_text(image_test3, word_data3, 'Chinese-CLIP/datasets/DatasetName/test3_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test3_texts.jsonl')
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
process_image_to_text(image_test3, data3_images_folder, word_data3, 'Chinese-CLIP/datasets/DatasetName/test3_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test3_texts.jsonl')


'''
Expand Down

0 comments on commit 85f3fa3

Please sign in to comment.