Files
DDT/tools/cat_images.py
wangshuai6 06499f1caa submit code
2025-04-09 11:01:16 +08:00

44 lines
1.3 KiB
Python

import cv2
import numpy as np
import os
import pathlib
import argparse
def group_images(path_list):
    """Group image paths by the class ID encoded in their file names.

    The class ID is the part of the file name before the first underscore
    (for ``7_0003.png`` the class ID is ``'7'``; a name without an underscore
    is its own class ID).

    Args:
        path_list: Iterable of ``pathlib.Path``-like objects (anything with a
            ``name`` attribute that supports ordering).

    Returns:
        A dict mapping each class ID string to the list of its paths. Paths
        are processed in sorted order, so both the per-class lists and the
        dict's key order are deterministic. The input is not modified.
    """
    class_id_dict = {}
    # sorted() returns a new list; the result has to be used, otherwise the
    # grouping order depends on whatever order the caller's listing produced.
    for path in sorted(path_list):
        class_id = str(path.name).split('_')[0]
        class_id_dict.setdefault(class_id, []).append(path)
    return class_id_dict
def cat_images(path_list):
    """Tile images into one near-square grid and delete the source files.

    The images are laid out row by row, in the order given, with
    ``int(sqrt(n))`` images per row and ``n // row_length`` rows. When ``n``
    does not fill the last row completely, the trailing images are left out
    of the grid (they are still deleted, as before).

    Args:
        path_list: Iterable of paths to image files readable by
            ``cv2.imread``. Concatenation requires the images in a row to
            share a height and the rows to share a width.

    Returns:
        The concatenated image as a NumPy array.

    Raises:
        ValueError: If ``path_list`` is empty, if an image cannot be read,
            or (from NumPy) if the image shapes are incompatible.
    """
    paths = list(path_list)
    if not paths:
        raise ValueError('cat_images() needs at least one image path')

    imgs = []
    for path in paths:
        img = cv2.imread(str(path))
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise ValueError(f'could not read image: {path}')
        imgs.append(img)

    row_length = int(len(imgs) ** 0.5)
    row_count = len(imgs) // row_length
    row_cat_images = [
        np.concatenate(imgs[row * row_length:(row + 1) * row_length], axis=1)
        for row in range(row_count)
    ]
    cat_image = np.concatenate(row_cat_images, axis=0)

    # Delete the sources only once the grid exists, so a read or shape error
    # above cannot destroy the input images.
    for path in paths:
        os.remove(path)
    return cat_image
if __name__ == '__main__':
    # Command-line entry point: for every class ID found among the PNG files
    # in --src_dir, write one tiled image `cat_<class_id>.jpg` into the same
    # directory.
    parser = argparse.ArgumentParser(
        description='Tile the PNG images in a directory into one grid image per class ID.'
    )
    parser.add_argument(
        '--src_dir',
        type=str,
        required=True,
        help='directory containing the <class_id>_*.png images',
    )
    args = parser.parse_args()
    src_dir = args.src_dir
    if not os.path.isdir(src_dir):
        parser.error(f'--src_dir is not a directory: {src_dir}')

    png_paths = list(pathlib.Path(src_dir).glob('*.png'))
    for class_id, class_paths in group_images(png_paths).items():
        cat_image = cat_images(class_paths)
        cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
        # cv2.imwrite returns False on failure instead of raising; surface it
        # so a failed write does not go unnoticed.
        if not cv2.imwrite(cat_path, cat_image):
            raise OSError(f'failed to write {cat_path}')