image/main.py

248 lines
9.9 KiB
Python
Raw Permalink Normal View History

2023-09-27 14:05:51 +08:00
import simpleimageio as sio
import figuregen
from figuregen.util.templates import FullSizeWithCrops
from figuregen.util.image import Cropbox
import os
import argparse
import yaml
import numpy as np
class MyCrops(FullSizeWithCrops):
    """Comparison row: every image at full size, each followed by its own crops.

    Every image -- the reference and each method -- gets its own full-size grid (with a
    white marker for every crop) and its own crop grid. ``self._ref_grid`` and
    ``self._crop_grid`` are therefore lists of ``figuregen.Grid`` indexed like
    ``[reference_image, *method_images]``.

    ``__init__`` does not call ``super().__init__()``. ``tonemap``, ``compute_error`` and
    the ``figure`` attribute that callers read are not defined here; they come from
    ``FullSizeWithCrops``.
    """

    def __init__(self, reference_image, method_images, crops,
                 crops_below=True, method_names=None, use_latex=False, add=False):
        """Build the full-size grids and crop grids for the reference and every method.

        Args:
            reference_image: first image of the row (``main`` passes the "original" image);
                every method is compared against it via ``compute_error``.
            method_images: list of images, one per method, laid out after the reference.
            crops: list of ``Cropbox`` objects; each is marked on every full-size image and
                cut out of every image.
            crops_below: if True, the crops of an image form one row (``len(crops)``
                columns); if False, one column.
            method_names: [optional] names for the reference and each method, in order
                (reference first). Only read when ``add`` and ``crops_below`` are both True.
            use_latex: stored on the instance and consulted by ``error_string``.
            add: if True (crops-below layout only), put ``method_names[i]`` as a bottom
                title under the i-th crop grid.

        Raises:
            ValueError: if ``add`` and ``crops_below`` are True but ``method_names`` is
                missing or has fewer than ``1 + len(method_images)`` entries.
        """
        self._reference_image = reference_image
        self._method_images = method_images
        self.use_latex = use_latex
        self._crops_below = crops_below

        # The reference comes first, followed by the methods in the order given.
        images = [reference_image, *method_images]

        # Fail early with a clear message instead of a TypeError/IndexError from the
        # title loop at the end of this method.
        if add and crops_below and (method_names is None or len(method_names) < len(images)):
            raise ValueError(
                f"add=True needs method_names for the reference and each method "
                f"({len(images)} names expected)"
            )

        # Errors of every method against the reference, for the whole image and per crop.
        # Nothing else in this class reads them yet (see the TODO at the end of __init__).
        self._errors = [self.compute_error(reference_image, m) for m in method_images]
        self._crop_errors = [
            [self.compute_error(crop.crop(reference_image), crop.crop(m)) for m in method_images]
            for crop in crops
        ]

        # One 1x1 grid per image: the tonemapped full-size image plus a white marker
        # for every crop region.
        self._ref_grid = []
        for image in images:
            grid = figuregen.Grid(1, 1)
            grid[0, 0].image = self.tonemap(image)
            for crop in crops:
                grid[0, 0].set_marker(crop.marker_pos, crop.marker_size, color=[255, 255, 255])
            self._ref_grid.append(grid)

        # One crop grid per image: a row of crops when crops_below, otherwise a column.
        self._crop_grid = []
        for image in images:
            if crops_below:
                grid = figuregen.Grid(num_cols=len(crops), num_rows=1)
                for col, crop in enumerate(crops):
                    grid[0, col].image = self.tonemap(crop.crop(image))
            else:
                grid = figuregen.Grid(num_cols=1, num_rows=len(crops))
                for row, crop in enumerate(crops):
                    grid[row, 0].image = self.tonemap(crop.crop(image))
            self._crop_grid.append(grid)

        # Pad the grids of every image except the last on the right. In the crops-below
        # layout, also pad every full-size grid at the bottom.
        for ref_grid, crop_grid in zip(self._ref_grid[:-1], self._crop_grid[:-1]):
            ref_grid.layout.set_padding(right=1)
            crop_grid.layout.set_padding(right=1)
            if crops_below:
                ref_grid.layout.set_padding(bottom=1)
        if crops_below:
            self._ref_grid[-1].layout.set_padding(bottom=1)
        else:
            # Side-by-side layout: the last full-size image still needs a gap on its right.
            self._ref_grid[-1].layout.set_padding(right=1)

        # Method names as bottom titles under each crop row (crops-below layout only).
        if crops_below and add:
            for crop_grid, name in zip(self._crop_grid, method_names):
                crop_grid.set_title("bottom", name)
                crop_grid.layout.set_title("bottom", 6, 1, 8)
        # TODO: error values (error_string / error_metric_name) are not shown yet. Showing
        # them needs titles spanning multiple grids (an image and its crops), should be
        # combined with the method names (show both or neither), and needs matching
        # paddings for alignment. The crops-to-the-right layout gets no titles at all.

    def error_string(self, index, errors):
        """Return the caption text for the ``index``-th value in ``errors``.

        Args:
            index: index in the list of errors
            errors: list of error values, one per method, in order

        Every branch currently returns an empty string, so error captions stay blank.
        The smallest value (``np.argmin``) under LaTeX, other LaTeX values and plain text
        keep separate branches (presumably placeholders for per-format text).
        """
        if self.use_latex and index == np.argmin(errors):
            return ""
        elif self.use_latex:
            return ""
        else:
            return ""

    def error_metric_name(self) -> str:
        """Return the error-metric label; an empty string in this subclass."""
        return ""
# Baseline methods, in column order; each is a sub-directory of data_root.
_BASELINE_METHODS = ("KinD", "LIME", "NPE", "SRIE", "ZeroDCE")
# Column labels: the reference image, every baseline, then our method.
_METHOD_NAMES = ("Reference", *_BASELINE_METHODS, "Ours")


def _read_row_images(data_root, dataset_type, image_file, ours_method):
    """Read the "original" image and every method's result for one image file.

    Each image is read from ``<data_root>/<method>/<dataset_type>/<image_file>``.

    Returns:
        ``(reference, method_images)``; ``method_images`` follows ``_BASELINE_METHODS``
        order with the ``ours_method`` result appended last.
    """
    def read(method):
        return sio.read(os.path.join(data_root, method, dataset_type, image_file))

    reference = read("original")
    method_images = [read(method) for method in (*_BASELINE_METHODS, ours_method)]
    return reference, method_images


def _to_cropboxes(crops):
    """Turn config entries ``[top, left, height, width, scale, ...]`` into ``Cropbox`` objects."""
    return [
        Cropbox(top=crop[0], left=crop[1], height=crop[2], width=crop[3], scale=crop[4])
        for crop in crops
    ]


def _build_row(reference, method_images, crops, add=False):
    """Lay out one comparison row with ``MyCrops`` and return its ``figure`` attribute."""
    return MyCrops(
        reference_image=reference,
        method_images=method_images,
        crops=_to_cropboxes(crops),
        method_names=list(_METHOD_NAMES),
        use_latex=True,
        add=add,
    ).figure


def main(data_root, image_file_1, image_file_2, image_file_3, dataset_type, ours_method,
         crops_1, crops_2, crops_3, width_cm, output_image,
         ours_method_1=None, ours_method_2=None, ours_method_3=None):
    """Render a three-row comparison figure and write it to ``output_image``.

    Each row shows one image file as found under ``<data_root>/original``, every baseline
    in ``_BASELINE_METHODS``, and our method, each with the row's crops.

    Args:
        data_root: directory holding one sub-directory per method (plus "original").
        image_file_1, image_file_2, image_file_3: file name used for rows 1, 2 and 3.
        dataset_type: sub-directory inside every method directory.
        ours_method: directory name of our method's results, used for every row unless a
            per-row override is given.
        crops_1, crops_2, crops_3: per-row lists of ``[top, left, height, width, scale]``.
        width_cm: figure width passed to ``figuregen.figure``.
        output_image: output file path; made absolute before writing.
        ours_method_1, ours_method_2, ours_method_3: optional (falsy = unset) per-row
            overrides of ``ours_method``.
    """
    output_image = os.path.abspath(output_image)
    image_files = (image_file_1, image_file_2, image_file_3)
    ours_overrides = (ours_method_1, ours_method_2, ours_method_3)
    row_crops = (crops_1, crops_2, crops_3)

    # Read every image before building any layout (as the script always has), so a
    # missing file is reported before any figure work starts.
    loaded = [
        _read_row_images(data_root, dataset_type, image_file, override or ours_method)
        for image_file, override in zip(image_files, ours_overrides)
    ]

    elements = []
    last_row = len(loaded) - 1
    for row, ((reference, method_images), crops) in enumerate(zip(loaded, row_crops)):
        # Only the bottom row gets the method names as titles under its crops.
        figure = _build_row(reference, method_images, crops, add=(row == last_row))
        elements.extend([figure[0], figure[1]])

    figuregen.figure(elements, width_cm=width_cm, filename=output_image)
if __name__ == "__main__":
    # Command-line entry point: every keyword argument of main() comes from one YAML file.
    parser = argparse.ArgumentParser(
        description="Render the comparison figure described by a YAML config file."
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./config/compare3_1.yaml",
        help="YAML file whose top-level keys are the keyword arguments of main()",
    )
    args = parser.parse_args()
    # Read as UTF-8 explicitly so parsing does not depend on the platform's locale encoding.
    with open(args.config, "r", encoding="utf-8") as f:
        # safe_load only constructs plain YAML types, never arbitrary Python objects.
        config = yaml.safe_load(f)
    # An empty file loads as None; report that clearly instead of failing inside main(**...).
    if not isinstance(config, dict):
        parser.error(f"{args.config} must contain a YAML mapping of main() arguments")
    main(**config)