DF/DN on CIFAR¶
[1]:
# Import necessary packages
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
[2]:
# Define color palette
sns.set(color_codes=True, style="white", context="talk", font_scale=1.5)
[3]:
def load_result(filename):
"""
Loads results from specified file
"""
inputs = open(filename, "r")
lines = inputs.readlines()
ls = []
for line in lines:
ls.append(float(line.strip()))
return ls
def load_results(prefix):
"""
Loads results from specified files
"""
acc_ls = []
time_ls = []
for name in names:
acc_ls.append(load_result(prefix + name + ".txt"))
time_ls.append(load_result(prefix + name + "_train_time.txt"))
return acc_ls, time_ls
def load_results_st(prefix):
"""
Loads results from similar time files
"""
acc_ls = []
time_ls = []
for name in names:
if name != "naive_rf" and name != "svm":
acc_ls.append(load_result(prefix + name + "_st.txt"))
time_ls.append(load_result(prefix + name + "_train_time_st.txt"))
else:
acc_ls.append(load_result(prefix + name + ".txt"))
time_ls.append(load_result(prefix + name + "_train_time.txt"))
return acc_ls, time_ls
def load_results_sc(prefix):
"""
Loads results from similar cost files
"""
acc_ls = []
time_ls = []
for name in names:
if name != "naive_rf" and name != "svm":
acc_ls.append(load_result(prefix + name + "_sc.txt"))
time_ls.append(load_result(prefix + name + "_train_time_sc.txt"))
else:
acc_ls.append(load_result(prefix + name + ".txt"))
time_ls.append(load_result(prefix + name + "_train_time_lc.txt"))
return acc_ls, time_ls
def produce_mean(ls):
"""
Produces means from list of 8 results
"""
ls_space = []
for i in range(int(len(ls) / 8)):
l = ls[i * 8 : (i + 1) * 8]
ls_space.append(l)
return np.mean(ls_space, axis=0)
[4]:
def plot_acc(col, accs, pos, samples_space):
# Plot low alpha results
for k in range(45):
col.plot(
samples_space,
accs[pos][0][k * 8 : (k + 1) * 8],
color="#e41a1c",
alpha=0.1,
)
col.plot(
samples_space,
accs[pos][1][k * 8 : (k + 1) * 8],
color="#377eb8",
alpha=0.1,
)
col.plot(
samples_space,
accs[pos][2][k * 8 : (k + 1) * 8],
color="#377eb8",
linestyle="dashed",
alpha=0.1,
)
col.plot(
samples_space,
accs[pos][3][k * 8 : (k + 1) * 8],
color="#377eb8",
linestyle="dotted",
alpha=0.1,
)
col.plot(
samples_space,
accs[pos][4][k * 8 : (k + 1) * 8],
color="#4daf4a",
alpha=0.1,
)
if pos == 0:
# Plot mean results
col.plot(
samples_space,
produce_mean(accs[pos][1]),
linewidth=5,
color="#377eb8",
label="CNN-1L",
)
col.plot(
samples_space,
produce_mean(accs[pos][0]),
linewidth=5,
color="#e41a1c",
label="RF",
)
col.plot(
samples_space,
produce_mean(accs[pos][2]),
linewidth=5,
color="#377eb8",
linestyle="dashed",
label="CNN-2L",
)
col.plot(
samples_space,
produce_mean(accs[pos][4]),
linewidth=5,
color="#4daf4a",
label="ResNet-18",
)
col.plot(
samples_space,
produce_mean(accs[pos][3]),
linewidth=5,
color="#377eb8",
linestyle="dotted",
label="CNN-5L",
)
else:
col.plot(
samples_space,
produce_mean(accs[pos][1]),
linewidth=5,
color="#377eb8",
)
col.plot(
samples_space,
produce_mean(accs[pos][0]),
linewidth=5,
color="#e41a1c",
)
col.plot(
samples_space,
produce_mean(accs[pos][4]),
linewidth=5,
color="#4daf4a",
)
col.plot(
samples_space,
produce_mean(accs[pos][2]),
linewidth=5,
color="#377eb8",
linestyle="dashed",
)
col.plot(
samples_space,
produce_mean(accs[pos][3]),
linewidth=5,
color="#377eb8",
linestyle="dotted",
)
def plot_six():
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(17, 11), constrained_layout=True)
fig.text(0.53, -0.05, "Number of Train Samples", ha="center")
xtitles = ["3 Classes", "8 Classes", "90 Classes"]
ytitles = ["Accuracy", "Wall Time (s)"]
ylimits = [[0, 1], [1e-2, 1e2]]
yticks = [[0, 0.5, 1], [1e-2, 1e0, 1e2]]
for i, row in enumerate(ax):
for j, col in enumerate(row):
count = 3 * i + j
col.set_xscale("log")
col.set_ylim(ylimits[i])
if count == 2 or count == 5:
samples_space = np.geomspace(100, 10000, num=8, dtype=int)
else:
samples_space = np.geomspace(10, 10000, num=8, dtype=int)
# Label x axis and plot figures
if count < 3:
col.set_xticks([])
col.set_title(xtitles[j])
plot_acc(col, accs, j, samples_space)
else:
if count == 5:
col.set_xticks([1e2, 1e3, 1e4])
else:
col.set_xticks([1e1, 1e2, 1e3, 1e4])
col.set_yscale("log")
plot_acc(col, accs, j + 3, samples_space)
# Label y axis
if count % 3 == 0:
col.set_yticks(yticks[i])
col.set_ylabel(ytitles[i])
else:
col.set_yticks([])
fig.align_ylabels(
ax[
:,
]
)
leg = fig.legend(
bbox_to_anchor=(0.53, -0.2),
bbox_transform=plt.gcf().transFigure,
ncol=3,
loc="lower center",
)
leg.get_frame().set_linewidth(0.0)
for legobj in leg.legendHandles:
legobj.set_linewidth(5.0)
DF/DN with Unbounded Time & Cost¶
[5]:
directory = "../benchmarks/vision/"
names = ["naive_rf", "cnn32", "cnn32_2l", "cnn32_5l", "resnet18"]
# Load 3-classes results
acc_3, time_3 = load_results(directory + "3_class/")
# Load 8-classes results
acc_8, time_8 = load_results(directory + "8_class/")
# Load 90-classes results
acc_90, time_90 = load_results(directory + "90_class/")
accs = [acc_3, acc_8, acc_90, time_3, time_8, time_90]
plot_six()
plt.savefig("../paper/figures/cifar.pdf", transparent=True, bbox_inches="tight")
DF/DN with Fixed Training Time¶
[6]:
# Load 3-classes results
acc_3, time_3 = load_results_st(directory + "3_class/")
# Load 8-classes results
acc_8, time_8 = load_results_st(directory + "8_class/")
# Load 90-classes results
acc_90, time_90 = load_results_st(directory + "90_class/")
accs = [acc_3, acc_8, acc_90, time_3, time_8, time_90]
plot_six()
plt.savefig("../paper/figures/cifar_st.pdf", transparent=True, bbox_inches="tight")
DF/DN with Fixed Training Cost¶
[7]:
# Load 3-classes results
acc_3, time_3 = load_results_sc(directory + "3_class/")
# Load 8-classes results
acc_8, time_8 = load_results_sc(directory + "8_class/")
# Load 90-classes results
acc_90, time_90 = load_results_sc(directory + "90_class/")
accs = [acc_3, acc_8, acc_90, time_3, time_8, time_90]
plot_six()
plt.savefig("../paper/figures/cifar_sc.pdf", transparent=True, bbox_inches="tight")