import torchvision
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import pickle
import numpy as np
import sys
sys.path.append("..")
from data.dataset import construct_datasets
from model.model import construct_model, set_random_seed
import argparse
from collections import defaultdict
from torchvision.datasets.folder import ImageFolder
from modellib import LeNetZhu
default_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
])
trainset = ImageFolder("../data/celeba/img_align_celeba/split_celeba/Train", transform = default_transform)
valset = ImageFolder("../data/celeba/img_align_celeba/split_celeba/Test", transform = default_transform)
ngf = 64
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(768, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.Sigmoid(),
            # nn.Tanh()
            # state size. (nc) x 64 x 64
            nn.MaxPool2d(2, 2),
        )

    def forward(self, input):  # [1, 768, 1, 1]
        return self.main(input)
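# Shape sanity check (a minimal sketch, assuming the Generator above): a 768-dim
# latent reshaped to (1, 768, 1, 1) is upsampled by the DCGAN-style stack to
# 64x64, and the final MaxPool2d halves it back to the 32x32 CelebA crop size.
with torch.no_grad():
    out = Generator()(torch.randn(1, 768, 1, 1))
print(out.shape)  # expected: torch.Size([1, 3, 32, 32])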
class Gradinversion(nn.Module):
    def __init__(self):
        super(Gradinversion, self).__init__()
        self.generator = Generator()

    def forward(self, x):  # x: torch.Size([10, 768])
        x = x.reshape(x.size()[0], x.size()[1], 1, 1)
        x = self.generator(x)
        x = torch.mean(x, dim=0)  # torch.Size([3, 32, 32])
        return x
gi = Gradinversion()
class_num = 20
extra_class = 0
att_model = LeNetZhu(num_classes = class_num + extra_class, num_channels=3)
gi.load_state_dict(torch.load("./model/gi_lenet_epoch500_2000img(formal_class_20).pkl"))
grad_model = att_model
dataset = trainset
gi.train()
gi.to("cpu")
grad_model.to("cpu")
sample = dataset[0][0]
plt.imshow(sample.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
pred = grad_model(sample.view(-1,3,32,32))
recon_list = []
loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([0]))
grad = torch.autograd.grad(loss, grad_model.parameters(), retain_graph=True)
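# grad follows the order of grad_model.parameters(); for LeNetZhu the last two
# entries are fc.weight (shape [20, 768] for this 20-class model) and fc.bias,
# so grad[-2] is the gradient of the classifier weight, and each of its rows is
# used below as a 768-dim latent code for the generator.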
grad_input = grad[-2][:20].reshape(20, 768, 1, 1)
recons = gi(grad_input)
plt.imshow(recons.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
x = gi.generator(grad_input)
plt.imshow(x[0].detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.imshow(x[1].detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.imshow(x[19].detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.imshow(torch.mean(x, dim=0).detach().cpu().numpy().transpose(1,2,0))
class ConvFeatureExtraction(nn.Module):
    def __init__(self):
        super(ConvFeatureExtraction, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3, stride=2)
        self.conv2 = nn.Conv2d(1, 1, 3, stride=1)
        self.linear1 = nn.Linear(169, 128)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.flatten(start_dim=1)
        x = self.linear1(x)
        return x
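# Shape sanity check (a sketch, assuming 32x32 RGB inputs as produced by the
# generator): conv1 (3->1, k=3, stride 2) gives 15x15, conv2 (1->1, k=3) gives
# 13x13 = 169 values per image, which linear1 maps to a 128-dim feature.
with torch.no_grad():
    feats = ConvFeatureExtraction()(torch.randn(10, 3, 32, 32))
print(feats.shape)  # expected: torch.Size([10, 128])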
class GradinversionWithselfAttention(nn.Module):
    def __init__(self):
        super(GradinversionWithselfAttention, self).__init__()
        self.generator = Generator()
        self.feature = ConvFeatureExtraction()
        self.multihead_attn = nn.MultiheadAttention(embed_dim=128, num_heads=1, dropout=0.0)
        self.softmax = nn.Softmax(dim=1)
        self.Q = nn.Linear(128, 128)
        self.K = nn.Linear(128, 128)
        self.V = nn.Linear(128, 128)

    def forward(self, x):  # x: torch.Size([10, 768])
        x = x.reshape(x.size()[0], x.size()[1], 1, 1)
        x = self.generator(x)
        qkv = self.feature(x)
        qkv = qkv.unsqueeze(dim=0)
        q = self.Q(qkv)
        k = self.K(qkv)
        v = self.V(qkv)
        attn_output, attn_output_weights = self.multihead_attn(q, k, v)
        weight = attn_output_weights.reshape(1, -1)
        weight = self.softmax(weight)
        weight = weight.squeeze()
        # x = torch.mean(x, dim=0)  # torch.Size([3, 32, 32])
        res = 0
        for i in range(x.size(0)):
            res += weight[i] * x[i]
        return res, weight
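# Note: the loop above is just a weighted average of the generated images; an
# equivalent vectorized form (sketch) would be
#     res = (weight.view(-1, 1, 1, 1) * x).sum(dim=0)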
def train(grad_model, gi_model, trainset, optimizer, loss_fn, device, num_train, class_num=10):
    grad_model.to(device)
    gi_model.to(device)
    global_loss = 0
    for i in range(num_train):  # len(trainset)
        sample = trainset[i][0].to(device)
        pred = grad_model(sample.view(-1, 3, 32, 32))
        r_loss = 0
        for j in range(class_num):
            loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([j]).to(device))
            grad = torch.autograd.grad(loss, grad_model.parameters(), retain_graph=True)
            grad_input = grad[-2].reshape(class_num, 768, 1, 1).to(device)
            optimizer.zero_grad()
            reconstruction_loss = loss_fn(gi_model(grad_input)[0], sample)
            reconstruction_loss.backward()
            optimizer.step()
            r_loss += reconstruction_loss.item()
        global_loss += r_loss
        # print(f"i : {i}, Loss : {r_loss}")
    return global_loss
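# What this loop does: for every training image, a probe cross-entropy loss is
# computed against each of the class_num candidate labels in turn; the resulting
# fc-weight gradient (class_num x 768) is handed to gi_model, which is trained
# to reconstruct the original 32x32 image under the pixel-wise loss_fn (BCE in
# the runs below). Only gi_model's parameters sit in the optimizer; grad_model
# (the attacked LeNetZhu) is never updated and only supplies gradients.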
#query = torch.randn(1, 20, 128)
#multihead_attn = nn.MultiheadAttention(embed_dim = 128, num_heads = 1, dropout=0.0)
#attn_output, attn_output_weights = multihead_attn(query, query, query)
class_num = 10
att_model = LeNetZhu(num_classes = class_num, num_channels=3)
gi = GradinversionWithselfAttention()
optimizer = torch.optim.Adam(gi.parameters(), lr = 0.01)
hist = []
for epoch in range(100):
    global_loss = train(att_model, gi, trainset, optimizer, torch.nn.functional.binary_cross_entropy, "cuda",
                        num_train=100, class_num=class_num)
    print(f"epoch: {epoch}, global loss: {global_loss}")
    hist.append(global_loss)
epoch: 0, global loss: 671.5524678230286 epoch: 1, global loss: 663.8371953070164 epoch: 2, global loss: 659.5930935740471 epoch: 3, global loss: 658.2926070690155 epoch: 4, global loss: 655.612980723381 epoch: 5, global loss: 653.8608857393265 epoch: 6, global loss: 653.2967808246613 epoch: 7, global loss: 653.0444711744785 epoch: 8, global loss: 652.794310092926 epoch: 9, global loss: 652.5177665650845 epoch: 10, global loss: 652.4399910867214 epoch: 11, global loss: 652.2245920300484 epoch: 12, global loss: 652.097874224186 epoch: 13, global loss: 651.7978120744228 epoch: 14, global loss: 651.8684236705303 epoch: 15, global loss: 651.8364308774471 epoch: 16, global loss: 651.5037378966808 epoch: 17, global loss: 650.7593080401421 epoch: 18, global loss: 649.2994719147682 epoch: 19, global loss: 646.6307125091553 epoch: 20, global loss: 640.8626713752747 epoch: 21, global loss: 634.5940164327621 epoch: 22, global loss: 624.4373970329762 epoch: 23, global loss: 616.9449944198132 epoch: 24, global loss: 612.4960078299046 epoch: 25, global loss: 609.2418874502182 epoch: 26, global loss: 606.9881150126457 epoch: 27, global loss: 604.1266944408417 epoch: 28, global loss: 601.9513965249062 epoch: 29, global loss: 599.3962340056896 epoch: 30, global loss: 597.1397420763969 epoch: 31, global loss: 594.9806322753429 epoch: 32, global loss: 592.4846655130386 epoch: 33, global loss: 590.964960694313 epoch: 34, global loss: 589.1249485313892 epoch: 35, global loss: 587.7177795469761 epoch: 36, global loss: 586.5153777897358 epoch: 37, global loss: 585.0905412733555 epoch: 38, global loss: 583.0929714739323 epoch: 39, global loss: 581.5914468467236 epoch: 40, global loss: 580.2802640795708 epoch: 41, global loss: 578.9509879350662 epoch: 42, global loss: 577.5015483796597 epoch: 43, global loss: 576.4578168094158 epoch: 44, global loss: 575.4466687440872 epoch: 45, global loss: 574.2208180129528 epoch: 46, global loss: 573.175264954567 epoch: 47, global loss: 572.0233966112137 epoch: 48, global loss: 571.2589029669762 epoch: 49, global loss: 569.862211227417 epoch: 50, global loss: 569.0242480933666 epoch: 51, global loss: 568.1133631169796 epoch: 52, global loss: 567.0479202270508 epoch: 53, global loss: 566.3271011710167 epoch: 54, global loss: 565.315014988184 epoch: 55, global loss: 564.2411880791187 epoch: 56, global loss: 563.5841090977192 epoch: 57, global loss: 562.8328421413898 epoch: 58, global loss: 562.2441559433937 epoch: 59, global loss: 561.6786546111107 epoch: 60, global loss: 560.614624530077 epoch: 61, global loss: 559.6966196298599 epoch: 62, global loss: 558.9750942885876 epoch: 63, global loss: 558.2324139177799 epoch: 64, global loss: 557.0759356617928 epoch: 65, global loss: 556.6331144869328 epoch: 66, global loss: 556.2212075293064 epoch: 67, global loss: 555.7944973111153 epoch: 68, global loss: 555.2471287250519 epoch: 69, global loss: 554.3389156460762 epoch: 70, global loss: 553.372810870409 epoch: 71, global loss: 552.5664614439011 epoch: 72, global loss: 552.2058208882809 epoch: 73, global loss: 551.3575920760632 epoch: 74, global loss: 550.9325566589832 epoch: 75, global loss: 550.135292172432 epoch: 76, global loss: 549.5429366230965 epoch: 77, global loss: 548.9127770960331 epoch: 78, global loss: 548.0797218978405 epoch: 79, global loss: 547.5039230287075 epoch: 80, global loss: 546.9741947650909 epoch: 81, global loss: 546.7496129572392 epoch: 82, global loss: 545.9632824659348 epoch: 83, global loss: 545.3648994863033 epoch: 84, global loss: 544.9108284711838 
epoch: 85, global loss: 544.0350824892521 epoch: 86, global loss: 543.4242036342621 epoch: 87, global loss: 542.7037590742111 epoch: 88, global loss: 542.1161627173424 epoch: 89, global loss: 541.8322369754314 epoch: 90, global loss: 541.3161093592644 epoch: 91, global loss: 540.885465234518 epoch: 92, global loss: 540.5963389277458 epoch: 93, global loss: 539.9255212843418 epoch: 94, global loss: 539.3829241394997 epoch: 95, global loss: 539.0431137979031 epoch: 96, global loss: 538.7173920273781 epoch: 97, global loss: 538.0091341137886 epoch: 98, global loss: 537.6660176515579 epoch: 99, global loss: 536.9482679069042
# torch.save(gi.state_dict(),"./model/giwithattention(cls=10)_epoch500_500img.pkl")
gi.to("cpu")
att_model.to("cpu")
index = 100
sample = trainset[index][0]
plt.imshow(sample.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
import math
pred = att_model(sample.view(-1,3,32,32))
recon_list = []
classnum_in_gi = 10
loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([0]))
grad = torch.autograd.grad(loss, att_model.parameters(), retain_graph=True)
grad_input = grad[-2][:classnum_in_gi].reshape(classnum_in_gi, 768, 1, 1)
recons, weight = gi(grad_input)
plt.imshow(recons.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
weight
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000], grad_fn=<SqueezeBackward0>)
gi = GradinversionWithselfAttention()
gi.load_state_dict(torch.load("./model/giwithattention(cls=10)_epoch500_500img.pkl"))
class_num = 10
att_model = LeNetZhu(num_classes = class_num, num_channels=3)
gi.to("cpu")
att_model.to("cpu")
LeNetZhu(
  (body): Sequential(
    (0): Conv2d(3, 12, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): Sigmoid()
    (2): Conv2d(12, 12, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): Sigmoid()
    (4): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): Sigmoid()
  )
  (fc): Sequential(
    (0): Linear(in_features=768, out_features=10, bias=True)
  )
)
index = 1000
sample = trainset[index][0]
plt.imshow(sample.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
import math
pred = att_model(sample.view(-1,3,32,32))
recon_list = []
classnum_in_gi = 10
loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([0]))
grad = torch.autograd.grad(loss, att_model.parameters(), retain_graph=True)
grad_input = grad[-2][:classnum_in_gi].reshape(classnum_in_gi, 768, 1, 1)
recons, weight = gi(grad_input)
plt.imshow(recons.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
# Surprisingly, the weights are all identical.
weight
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000], grad_fn=<SqueezeBackward0>)
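# Why the attention weights are stuck at 0.1000: qkv.unsqueeze(dim=0) gives
# q/k/v the shape (1, 10, 128). With batch_first=False (the nn.MultiheadAttention
# default) this is read as sequence length 1 and batch size 10, so every
# attention matrix is 1x1 and its softmax is exactly 1.0. Softmaxing ten
# identical 1.0 scores then yields a constant 1/10 per image, regardless of the
# input, so the weighted sum degenerates to a plain mean and the weights cannot
# change during training. A minimal sketch reproducing the effect (hypothetical
# random tensors, not the trained model):
mha = nn.MultiheadAttention(embed_dim=128, num_heads=1)
_, w = mha(torch.randn(1, 10, 128), torch.randn(1, 10, 128), torch.randn(1, 10, 128))
print(w.flatten())  # ten exact 1.0 values -> Softmax(dim=1) maps them to 0.1 each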
class_num = 10
att_model = LeNetZhu(num_classes = class_num, num_channels=3)
gi = GradinversionWithselfAttention()
optimizer = torch.optim.Adam(gi.parameters(), lr = 0.01)
def train(grad_model, gi_model, trainset, optimizer, loss_fn, device, num_train, class_num=10):
    grad_model.to(device)
    gi_model.to(device)
    global_loss = 0
    for i in range(num_train):  # len(trainset)
        sample = trainset[i][0].to(device)
        pred = grad_model(sample.view(-1, 3, 32, 32))
        r_loss = 0
        for j in range(1):  # unlike the previous version, only the single probe label j = 0 is used
            loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([j]).to(device))
            grad = torch.autograd.grad(loss, grad_model.parameters(), retain_graph=True)
            grad_input = grad[-2].reshape(class_num, 768, 1, 1).to(device)
            optimizer.zero_grad()
            reconstruction_loss = loss_fn(gi_model(grad_input)[0], sample)
            reconstruction_loss.backward()
            optimizer.step()
            r_loss += reconstruction_loss.item()
        global_loss += r_loss
        # print(f"i : {i}, Loss : {r_loss}")
    return global_loss
hist = []
for epoch in range(500):
    global_loss = train(att_model, gi, trainset, optimizer, torch.nn.functional.binary_cross_entropy, "cuda",
                        num_train=500, class_num=class_num)
    print(f"epoch: {epoch}, global loss: {global_loss}")
    hist.append(global_loss)
epoch: 0, global loss: 338.60025918483734 epoch: 1, global loss: 337.04458051919937 epoch: 2, global loss: 336.44213062524796 epoch: 3, global loss: 336.1074105501175 epoch: 4, global loss: 335.4954474568367 epoch: 5, global loss: 333.28484576940536 epoch: 6, global loss: 327.5770261287689 epoch: 7, global loss: 315.8735587000847 epoch: 8, global loss: 312.13780242204666 epoch: 9, global loss: 309.8856797218323 epoch: 10, global loss: 308.7309779524803 epoch: 11, global loss: 307.7558245062828 epoch: 12, global loss: 306.9349880218506 epoch: 13, global loss: 306.3899496495724 epoch: 14, global loss: 305.58511874079704 epoch: 15, global loss: 305.03159829974174 epoch: 16, global loss: 304.236841827631 epoch: 17, global loss: 302.9832682609558 epoch: 18, global loss: 301.91884484887123 epoch: 19, global loss: 301.0420523583889 epoch: 20, global loss: 299.63857024908066 epoch: 21, global loss: 298.28356659412384 epoch: 22, global loss: 297.2090891599655 epoch: 23, global loss: 296.08047011494637 epoch: 24, global loss: 295.02819406986237 epoch: 25, global loss: 294.00377002358437 epoch: 26, global loss: 293.2163531780243 epoch: 27, global loss: 292.47331166267395 epoch: 28, global loss: 291.99924275279045 epoch: 29, global loss: 291.5046341121197 epoch: 30, global loss: 291.0650665163994 epoch: 31, global loss: 290.5462442934513 epoch: 32, global loss: 290.1719908118248 epoch: 33, global loss: 289.7578994035721 epoch: 34, global loss: 289.2986038327217 epoch: 35, global loss: 288.93326476216316 epoch: 36, global loss: 288.58343267440796 epoch: 37, global loss: 288.1862127184868 epoch: 38, global loss: 287.7282990217209 epoch: 39, global loss: 287.3022213280201 epoch: 40, global loss: 286.9106777906418 epoch: 41, global loss: 286.4804193377495 epoch: 42, global loss: 286.1487970650196 epoch: 43, global loss: 285.7585299015045 epoch: 44, global loss: 285.3255044221878 epoch: 45, global loss: 285.1598679125309 epoch: 46, global loss: 284.90957736968994 epoch: 47, global loss: 284.55337381362915 epoch: 48, global loss: 284.16781136393547 epoch: 49, global loss: 283.7935930490494 epoch: 50, global loss: 283.3358548283577 epoch: 51, global loss: 283.00090113282204 epoch: 52, global loss: 282.7107480764389 epoch: 53, global loss: 282.52693235874176 epoch: 54, global loss: 282.2320607602596 epoch: 55, global loss: 282.04321256279945 epoch: 56, global loss: 281.91945818066597 epoch: 57, global loss: 281.462922334671 epoch: 58, global loss: 281.0816715657711 epoch: 59, global loss: 281.0037013590336 epoch: 60, global loss: 280.70009222626686 epoch: 61, global loss: 280.47275269031525 epoch: 62, global loss: 280.0947594642639 epoch: 63, global loss: 280.0830429792404 epoch: 64, global loss: 280.38337859511375 epoch: 65, global loss: 280.2586654126644 epoch: 66, global loss: 279.9976105093956 epoch: 67, global loss: 279.58775195479393 epoch: 68, global loss: 279.328746765852 epoch: 69, global loss: 279.49448135495186 epoch: 70, global loss: 279.1056619286537 epoch: 71, global loss: 278.58737483620644 epoch: 72, global loss: 278.2464325428009 epoch: 73, global loss: 278.1430314183235 epoch: 74, global loss: 278.08497819304466 epoch: 75, global loss: 278.2858307361603 epoch: 76, global loss: 277.9487242400646 epoch: 77, global loss: 277.8523486852646 epoch: 78, global loss: 277.9532396197319 epoch: 79, global loss: 277.861212015152 epoch: 80, global loss: 277.8161629140377 epoch: 81, global loss: 277.36342176795006 epoch: 82, global loss: 276.9804410636425 epoch: 83, global loss: 277.0553270280361 epoch: 
84, global loss: 276.8374452292919 epoch: 85, global loss: 276.5743753015995 epoch: 86, global loss: 276.84409457445145 epoch: 87, global loss: 277.1064626276493 epoch: 88, global loss: 277.16203156113625 epoch: 89, global loss: 276.93172401189804 epoch: 90, global loss: 276.5795137286186 epoch: 91, global loss: 276.48773965239525 epoch: 92, global loss: 276.2845629155636 epoch: 93, global loss: 276.150188177824 epoch: 94, global loss: 275.72624921798706 epoch: 95, global loss: 275.5478581190109 epoch: 96, global loss: 275.6204836666584 epoch: 97, global loss: 275.8021914958954 epoch: 98, global loss: 275.4928195774555 epoch: 99, global loss: 275.23908081650734 epoch: 100, global loss: 274.8106293082237 epoch: 101, global loss: 274.44674864411354 epoch: 102, global loss: 274.322762131691 epoch: 103, global loss: 274.5765192806721 epoch: 104, global loss: 274.92705154418945 epoch: 105, global loss: 274.5397770702839 epoch: 106, global loss: 274.1805611848831 epoch: 107, global loss: 274.21314546465874 epoch: 108, global loss: 274.10285317897797 epoch: 109, global loss: 273.71711441874504 epoch: 110, global loss: 273.47996377944946 epoch: 111, global loss: 273.2850476205349 epoch: 112, global loss: 273.4228249490261 epoch: 113, global loss: 273.8601565659046 epoch: 114, global loss: 273.85707265138626 epoch: 115, global loss: 273.79134145379066 epoch: 116, global loss: 273.25904843211174 epoch: 117, global loss: 273.0848714709282 epoch: 118, global loss: 272.90549525618553 epoch: 119, global loss: 272.6331161260605 epoch: 120, global loss: 272.9493053853512 epoch: 121, global loss: 272.9758483171463 epoch: 122, global loss: 272.8044936954975 epoch: 123, global loss: 272.7653195261955 epoch: 124, global loss: 272.9390612542629 epoch: 125, global loss: 273.0032024681568 epoch: 126, global loss: 272.92705553770065 epoch: 127, global loss: 272.7667332291603 epoch: 128, global loss: 272.19098204374313 epoch: 129, global loss: 272.204122364521 epoch: 130, global loss: 272.43800419569016 epoch: 131, global loss: 272.3421573340893 epoch: 132, global loss: 272.2290526032448 epoch: 133, global loss: 272.59938633441925 epoch: 134, global loss: 272.00171104073524 epoch: 135, global loss: 271.30085161328316 epoch: 136, global loss: 270.9739597737789 epoch: 137, global loss: 271.12417608499527 epoch: 138, global loss: 271.5175721347332 epoch: 139, global loss: 271.37043672800064 epoch: 140, global loss: 271.2331156730652 epoch: 141, global loss: 271.10730946063995 epoch: 142, global loss: 271.01866939663887 epoch: 143, global loss: 270.95120072364807 epoch: 144, global loss: 271.14853858947754 epoch: 145, global loss: 271.0423741340637 epoch: 146, global loss: 270.73897886276245 epoch: 147, global loss: 270.97135388851166 epoch: 148, global loss: 271.18087002635 epoch: 149, global loss: 271.35382464528084 epoch: 150, global loss: 271.01262044906616 epoch: 151, global loss: 270.678875207901 epoch: 152, global loss: 270.6979446709156 epoch: 153, global loss: 270.9091091156006 epoch: 154, global loss: 270.8644742965698 epoch: 155, global loss: 270.7029717564583 epoch: 156, global loss: 270.4299880862236 epoch: 157, global loss: 270.303808927536 epoch: 158, global loss: 270.3775800168514 epoch: 159, global loss: 270.25759878754616 epoch: 160, global loss: 269.82628360390663 epoch: 161, global loss: 269.96221193671227 epoch: 162, global loss: 269.7560719549656 epoch: 163, global loss: 269.52191615104675 epoch: 164, global loss: 269.96368330717087 epoch: 165, global loss: 269.6465944349766 epoch: 166, global 
loss: 269.55197739601135 epoch: 167, global loss: 269.4696534574032 epoch: 168, global loss: 269.0192313492298 epoch: 169, global loss: 269.04565793275833 epoch: 170, global loss: 268.923408806324 epoch: 171, global loss: 268.89046081900597 epoch: 172, global loss: 269.1728350520134 epoch: 173, global loss: 269.2761588692665 epoch: 174, global loss: 269.0289306342602 epoch: 175, global loss: 268.56900057196617 epoch: 176, global loss: 268.7576342821121 epoch: 177, global loss: 268.982787579298 epoch: 178, global loss: 269.081582903862 epoch: 179, global loss: 268.71396151185036 epoch: 180, global loss: 268.5042599141598 epoch: 181, global loss: 268.28460681438446 epoch: 182, global loss: 268.2354346215725 epoch: 183, global loss: 268.65841230750084 epoch: 184, global loss: 268.355104804039 epoch: 185, global loss: 268.6642047762871 epoch: 186, global loss: 268.62434473633766 epoch: 187, global loss: 267.95514500141144 epoch: 188, global loss: 268.001448482275 epoch: 189, global loss: 268.28792253136635 epoch: 190, global loss: 268.68641623854637 epoch: 191, global loss: 268.25121665000916 epoch: 192, global loss: 268.0614136457443 epoch: 193, global loss: 267.9363207221031 epoch: 194, global loss: 267.8374087512493 epoch: 195, global loss: 268.03878623247147 epoch: 196, global loss: 268.19674867391586 epoch: 197, global loss: 268.499749571085 epoch: 198, global loss: 267.73247745633125 epoch: 199, global loss: 267.6727564036846 epoch: 200, global loss: 267.9237776696682 epoch: 201, global loss: 267.8691616356373 epoch: 202, global loss: 267.8297538757324 epoch: 203, global loss: 267.61287048459053 epoch: 204, global loss: 267.3559074997902 epoch: 205, global loss: 267.4039722084999 epoch: 206, global loss: 267.6859909892082 epoch: 207, global loss: 267.60051518678665 epoch: 208, global loss: 267.9663180410862 epoch: 209, global loss: 267.8091093301773 epoch: 210, global loss: 267.46151918172836 epoch: 211, global loss: 267.109820663929 epoch: 212, global loss: 267.088393419981 epoch: 213, global loss: 267.0687314271927 epoch: 214, global loss: 267.2452138364315 epoch: 215, global loss: 267.2136981189251 epoch: 216, global loss: 267.07490861415863 epoch: 217, global loss: 267.01893919706345 epoch: 218, global loss: 266.8023644685745 epoch: 219, global loss: 266.69594609737396 epoch: 220, global loss: 266.7091612815857 epoch: 221, global loss: 266.72398322820663 epoch: 222, global loss: 266.51622980833054 epoch: 223, global loss: 266.3694181740284 epoch: 224, global loss: 266.49278926849365 epoch: 225, global loss: 266.42015275359154 epoch: 226, global loss: 266.8500148653984 epoch: 227, global loss: 266.68752256035805 epoch: 228, global loss: 266.454810321331 epoch: 229, global loss: 266.66304737329483 epoch: 230, global loss: 266.4411944448948 epoch: 231, global loss: 266.5573633015156 epoch: 232, global loss: 266.67343243956566 epoch: 233, global loss: 266.4969413280487 epoch: 234, global loss: 266.5998729467392 epoch: 235, global loss: 266.7725577056408 epoch: 236, global loss: 266.4613431096077 epoch: 237, global loss: 266.6227987706661 epoch: 238, global loss: 266.42022210359573 epoch: 239, global loss: 266.54845687747 epoch: 240, global loss: 266.48305439949036 epoch: 241, global loss: 266.2867089807987 epoch: 242, global loss: 266.26241785287857 epoch: 243, global loss: 266.8284336924553 epoch: 244, global loss: 266.27072790265083 epoch: 245, global loss: 265.87149888277054 epoch: 246, global loss: 265.70375031232834 epoch: 247, global loss: 265.4910563826561 epoch: 248, global loss: 
265.5325318276882 epoch: 249, global loss: 265.62069869041443 epoch: 250, global loss: 265.88404163718224 epoch: 251, global loss: 266.2039574086666 epoch: 252, global loss: 265.983141630888 epoch: 253, global loss: 265.4022875726223 epoch: 254, global loss: 265.30491492152214 epoch: 255, global loss: 265.1420873105526 epoch: 256, global loss: 265.22370091080666 epoch: 257, global loss: 265.4125193655491 epoch: 258, global loss: 265.8117779493332 epoch: 259, global loss: 265.7197071015835 epoch: 260, global loss: 265.57474479079247 epoch: 261, global loss: 265.2903926372528 epoch: 262, global loss: 265.40443524718285 epoch: 263, global loss: 265.3066259920597 epoch: 264, global loss: 265.28966122865677 epoch: 265, global loss: 265.1402108669281 epoch: 266, global loss: 265.5858030617237 epoch: 267, global loss: 265.55771067738533 epoch: 268, global loss: 265.4581235945225 epoch: 269, global loss: 265.47352263331413 epoch: 270, global loss: 265.15094515681267 epoch: 271, global loss: 265.29064720869064 epoch: 272, global loss: 265.4582040011883 epoch: 273, global loss: 265.4547881782055 epoch: 274, global loss: 265.6645033955574 epoch: 275, global loss: 265.5247828066349 epoch: 276, global loss: 265.59850522875786 epoch: 277, global loss: 265.33320939540863 epoch: 278, global loss: 265.09297037124634 epoch: 279, global loss: 265.1541801095009 epoch: 280, global loss: 265.00482872128487 epoch: 281, global loss: 265.03508749604225 epoch: 282, global loss: 265.2006196677685 epoch: 283, global loss: 264.84514731168747 epoch: 284, global loss: 264.47645822167397 epoch: 285, global loss: 264.5370977818966 epoch: 286, global loss: 264.6341909468174 epoch: 287, global loss: 264.5076683461666 epoch: 288, global loss: 264.5891064405441 epoch: 289, global loss: 264.5197205245495 epoch: 290, global loss: 264.4897865355015 epoch: 291, global loss: 264.34834283590317 epoch: 292, global loss: 264.36202171444893 epoch: 293, global loss: 264.1250133216381 epoch: 294, global loss: 264.3029277920723 epoch: 295, global loss: 264.4219064414501 epoch: 296, global loss: 264.350359916687 epoch: 297, global loss: 264.1961720883846 epoch: 298, global loss: 264.22130024433136 epoch: 299, global loss: 264.08995401859283 epoch: 300, global loss: 264.42928609251976 epoch: 301, global loss: 264.5910578966141 epoch: 302, global loss: 264.2654378414154 epoch: 303, global loss: 263.9419307410717 epoch: 304, global loss: 263.87533032894135 epoch: 305, global loss: 264.1782941222191 epoch: 306, global loss: 264.40385088324547 epoch: 307, global loss: 264.2409340441227 epoch: 308, global loss: 264.06782975792885 epoch: 309, global loss: 264.0442868769169 epoch: 310, global loss: 264.11002841591835 epoch: 311, global loss: 263.82124987244606 epoch: 312, global loss: 263.9848325550556 epoch: 313, global loss: 264.03957572579384 epoch: 314, global loss: 264.1016899943352 epoch: 315, global loss: 263.90654093027115 epoch: 316, global loss: 263.80911788344383 epoch: 317, global loss: 263.8586626648903 epoch: 318, global loss: 264.04839649796486 epoch: 319, global loss: 264.0237063765526 epoch: 320, global loss: 263.73895341157913 epoch: 321, global loss: 264.0177748799324 epoch: 322, global loss: 264.0644153356552 epoch: 323, global loss: 263.8700003027916 epoch: 324, global loss: 263.6515619158745 epoch: 325, global loss: 263.75360587239265 epoch: 326, global loss: 263.70011165738106 epoch: 327, global loss: 263.64447075128555 epoch: 328, global loss: 263.414237678051 epoch: 329, global loss: 263.4543554186821 epoch: 330, global 
loss: 263.8412399291992 epoch: 331, global loss: 263.76962211728096 epoch: 332, global loss: 263.74514469504356 epoch: 333, global loss: 263.5174780189991 epoch: 334, global loss: 263.65144753456116 epoch: 335, global loss: 263.9130415916443 epoch: 336, global loss: 263.9026415646076 epoch: 337, global loss: 263.77352157235146 epoch: 338, global loss: 263.61850094795227 epoch: 339, global loss: 263.4765903055668 epoch: 340, global loss: 263.0621617138386 epoch: 341, global loss: 263.2357193529606 epoch: 342, global loss: 263.27302518486977 epoch: 343, global loss: 263.4739535152912 epoch: 344, global loss: 263.5399603545666 epoch: 345, global loss: 263.0525557100773 epoch: 346, global loss: 263.2581130862236 epoch: 347, global loss: 263.3414052724838 epoch: 348, global loss: 262.98551815748215 epoch: 349, global loss: 263.0038081407547 epoch: 350, global loss: 262.9502255618572 epoch: 351, global loss: 263.1234270334244 epoch: 352, global loss: 263.4570249915123 epoch: 353, global loss: 263.3413445651531 epoch: 354, global loss: 263.142136991024 epoch: 355, global loss: 262.99542793631554 epoch: 356, global loss: 262.8281972706318 epoch: 357, global loss: 262.808453142643 epoch: 358, global loss: 262.7352406382561 epoch: 359, global loss: 262.7309219241142 epoch: 360, global loss: 263.0659875869751 epoch: 361, global loss: 263.05039969086647 epoch: 362, global loss: 262.9881289601326 epoch: 363, global loss: 262.84519958496094 epoch: 364, global loss: 262.84456384181976 epoch: 365, global loss: 262.7072714269161 epoch: 366, global loss: 262.7783677279949 epoch: 367, global loss: 262.9154816865921 epoch: 368, global loss: 262.7148408293724 epoch: 369, global loss: 262.9098824560642 epoch: 370, global loss: 262.9107674062252 epoch: 371, global loss: 263.01483511924744 epoch: 372, global loss: 262.67527481913567 epoch: 373, global loss: 262.5063259899616 epoch: 374, global loss: 262.5769057273865 epoch: 375, global loss: 262.366546690464 epoch: 376, global loss: 262.4547892808914 epoch: 377, global loss: 262.40500324964523 epoch: 378, global loss: 262.6624390780926 epoch: 379, global loss: 262.8577553629875 epoch: 380, global loss: 262.8830107450485 epoch: 381, global loss: 262.6342525780201 epoch: 382, global loss: 262.2965740263462 epoch: 383, global loss: 262.3039802312851 epoch: 384, global loss: 262.53798374533653 epoch: 385, global loss: 262.7582617402077 epoch: 386, global loss: 262.8510323166847 epoch: 387, global loss: 262.71864822506905 epoch: 388, global loss: 262.67944303154945 epoch: 389, global loss: 262.5314585864544 epoch: 390, global loss: 262.7438705265522 epoch: 391, global loss: 262.9875753223896 epoch: 392, global loss: 262.85370460152626 epoch: 393, global loss: 262.790315926075 epoch: 394, global loss: 262.61300280690193 epoch: 395, global loss: 262.6014828681946 epoch: 396, global loss: 262.19198375940323 epoch: 397, global loss: 262.0079257488251 epoch: 398, global loss: 262.2615749537945 epoch: 399, global loss: 262.61439684033394 epoch: 400, global loss: 262.5630381703377 epoch: 401, global loss: 262.4550834298134 epoch: 402, global loss: 262.22085708379745 epoch: 403, global loss: 261.95099729299545 epoch: 404, global loss: 261.91884088516235 epoch: 405, global loss: 262.25828567147255 epoch: 406, global loss: 262.40313479304314 epoch: 407, global loss: 262.0523853600025 epoch: 408, global loss: 262.05873396992683 epoch: 409, global loss: 262.09865522384644 epoch: 410, global loss: 262.42329743504524 epoch: 411, global loss: 262.73788130283356 epoch: 412, global 
loss: 262.75496768951416 epoch: 413, global loss: 262.6305637359619 epoch: 414, global loss: 262.3641494810581 epoch: 415, global loss: 262.2569977045059 epoch: 416, global loss: 262.32249185442924 epoch: 417, global loss: 262.2327560186386 epoch: 418, global loss: 262.21284368634224 epoch: 419, global loss: 262.2796156704426 epoch: 420, global loss: 262.2471624612808 epoch: 421, global loss: 262.0672294795513 epoch: 422, global loss: 262.2383571565151 epoch: 423, global loss: 262.3979557454586 epoch: 424, global loss: 262.2135796546936 epoch: 425, global loss: 262.2017630636692 epoch: 426, global loss: 262.0764892101288 epoch: 427, global loss: 261.9091183543205 epoch: 428, global loss: 261.98374941945076 epoch: 429, global loss: 261.90986746549606 epoch: 430, global loss: 261.8498991429806 epoch: 431, global loss: 262.2920818030834 epoch: 432, global loss: 262.3755798339844 epoch: 433, global loss: 262.07216972112656 epoch: 434, global loss: 261.8689458370209 epoch: 435, global loss: 261.9564778804779 epoch: 436, global loss: 261.9558102488518 epoch: 437, global loss: 261.97024843096733 epoch: 438, global loss: 261.9320366382599 epoch: 439, global loss: 261.8419044315815 epoch: 440, global loss: 261.6301702260971 epoch: 441, global loss: 261.82230415940285 epoch: 442, global loss: 261.85898983478546 epoch: 443, global loss: 261.57121154665947 epoch: 444, global loss: 261.56499153375626 epoch: 445, global loss: 261.87683233618736 epoch: 446, global loss: 261.889323592186 epoch: 447, global loss: 261.9064636528492 epoch: 448, global loss: 261.67909309268 epoch: 449, global loss: 261.71322786808014 epoch: 450, global loss: 262.01076209545135 epoch: 451, global loss: 261.9387253522873 epoch: 452, global loss: 261.56914725899696 epoch: 453, global loss: 261.56736543774605 epoch: 454, global loss: 261.6368360221386 epoch: 455, global loss: 261.44825714826584 epoch: 456, global loss: 261.62926894426346 epoch: 457, global loss: 261.5140478014946 epoch: 458, global loss: 261.71411910653114 epoch: 459, global loss: 261.6664963066578 epoch: 460, global loss: 261.64514419436455 epoch: 461, global loss: 261.66424775123596 epoch: 462, global loss: 261.51387003064156 epoch: 463, global loss: 261.4502768814564 epoch: 464, global loss: 261.5433467030525 epoch: 465, global loss: 261.5068618655205 epoch: 466, global loss: 261.49896973371506 epoch: 467, global loss: 261.48572209477425 epoch: 468, global loss: 261.5729942023754 epoch: 469, global loss: 261.51336818933487 epoch: 470, global loss: 261.5104151368141 epoch: 471, global loss: 261.5890080332756 epoch: 472, global loss: 261.40349024534225 epoch: 473, global loss: 261.24343913793564 epoch: 474, global loss: 261.36492812633514 epoch: 475, global loss: 261.3303753733635 epoch: 476, global loss: 261.47334709763527 epoch: 477, global loss: 261.5734957754612 epoch: 478, global loss: 261.6231796145439 epoch: 479, global loss: 261.4435556232929 epoch: 480, global loss: 261.23709240555763 epoch: 481, global loss: 261.1772166490555 epoch: 482, global loss: 261.2189137041569 epoch: 483, global loss: 261.3589496612549 epoch: 484, global loss: 261.42102670669556 epoch: 485, global loss: 261.3118377029896 epoch: 486, global loss: 261.3803126513958 epoch: 487, global loss: 261.258409768343 epoch: 488, global loss: 261.3234350979328 epoch: 489, global loss: 261.4172223210335 epoch: 490, global loss: 261.6860262155533 epoch: 491, global loss: 261.4811144769192 epoch: 492, global loss: 261.18531879782677 epoch: 493, global loss: 261.0046935081482 epoch: 494, global 
loss: 261.13155138492584 epoch: 495, global loss: 261.16701820492744 epoch: 496, global loss: 261.1627837717533 epoch: 497, global loss: 261.1643570959568 epoch: 498, global loss: 261.1084181666374 epoch: 499, global loss: 261.2531746029854
torch.save(gi.state_dict(),"./model/giwithattention(cls=10_trainwith1)_epoch500_500img.pkl")
gi.to("cpu")
att_model.to("cpu")
index = 1000
sample = trainset[index][0]
plt.imshow(sample.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
import math
pred = att_model(sample.view(-1,3,32,32))
recon_list = []
classnum_in_gi = 10
loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([0]))
grad = torch.autograd.grad(loss, att_model.parameters(), retain_graph=True)
grad_input = grad[-2][:classnum_in_gi].reshape(classnum_in_gi, 768, 1, 1)
recons, weight = gi(grad_input)
plt.imshow(recons.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
weight
tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000], grad_fn=<SqueezeBackward0>)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(768, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.SELU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.SELU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.SELU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.SELU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.Sigmoid(),
            # nn.Tanh()
            # state size. (nc) x 64 x 64
            nn.MaxPool2d(2, 2),
        )

    def forward(self, input):  # [1, 768, 1, 1]
        return self.main(input)
class GradinversionWithselfAttention(nn.Module):
    def __init__(self):
        super(GradinversionWithselfAttention, self).__init__()
        self.generator = Generator()
        self.feature = ConvFeatureExtraction()
        self.softmax = nn.Softmax(dim=1)
        self.Q = nn.Linear(128, 128)
        self.K = nn.Linear(128, 128)

    def forward(self, x):  # x: torch.Size([10, 768])
        x = x.reshape(x.size()[0], x.size()[1], 1, 1)
        x = self.generator(x)
        qk = self.feature(x)
        qk = qk.unsqueeze(dim=0)
        q = self.Q(qk)
        k = self.K(qk)
        score = torch.bmm(q, k.permute(0, 2, 1))
        weight = self.softmax(torch.sum(score, dim=2))
        tmp = weight.reshape(10, 1, 1, 1) * x
        weight_sum = torch.sum(tmp, dim=0)
        return weight_sum, weight
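# Here the weights come from raw dot-product scores: score[0, i, j] = q_i · k_j,
# summed over j and softmaxed, i.e. weight_i = softmax_i(sum_j q_i · k_j).
# Unlike the nn.MultiheadAttention variant above, these scores depend on the
# generated images, so the weights can actually move during training (see the
# one-hot weight printed further below). Note that weight.reshape(10, 1, 1, 1)
# hard-codes a batch of 10 images; reshape(-1, 1, 1, 1) would be a safer choice.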
def train(grad_model, gi_model, trainset, optimizer, loss_fn, device, num_train, class_num=10):
    grad_model.to(device)
    gi_model.to(device)
    global_loss = 0
    for i in range(num_train):  # len(trainset)
        sample = trainset[i][0].to(device)
        pred = grad_model(sample.view(-1, 3, 32, 32))
        r_loss = 0
        for j in range(class_num):
            loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([j]).to(device))
            grad = torch.autograd.grad(loss, grad_model.parameters(), retain_graph=True)
            grad_input = grad[-2].reshape(class_num, 768, 1, 1).to(device)
            optimizer.zero_grad()
            reconstruction_loss = loss_fn(gi_model(grad_input)[0], sample)
            reconstruction_loss.backward()
            optimizer.step()
            r_loss += reconstruction_loss.item()
        global_loss += r_loss
        # print(f"i : {i}, Loss : {r_loss}")
    return global_loss
class_num = 10
att_model = LeNetZhu(num_classes = class_num, num_channels=3)
gi = GradinversionWithselfAttention()
optimizer = torch.optim.Adam(gi.parameters(), lr = 0.0001)
hist = []
for epoch in range(300):
    global_loss = train(att_model, gi, trainset, optimizer, torch.nn.functional.binary_cross_entropy, "cuda",
                        num_train=100, class_num=class_num)
    print(f"epoch: {epoch}, global loss: {global_loss}")
    hist.append(global_loss)
epoch: 0, global loss: 674.4825942516327 epoch: 1, global loss: 667.1935893297195 epoch: 2, global loss: 665.4668490290642 epoch: 3, global loss: 660.592010140419 epoch: 4, global loss: 655.6120321154594 epoch: 5, global loss: 654.4572462439537 epoch: 6, global loss: 652.1957456469536 epoch: 7, global loss: 659.1984488368034 epoch: 8, global loss: 652.4935920834541 epoch: 9, global loss: 664.1487569212914 epoch: 10, global loss: 658.6552041172981 epoch: 11, global loss: 655.8291607499123 epoch: 12, global loss: 653.2029873132706 epoch: 13, global loss: 650.5847187638283 epoch: 14, global loss: 648.1683402657509 epoch: 15, global loss: 645.5536907911301 epoch: 16, global loss: 642.5484381914139 epoch: 17, global loss: 638.8860410451889 epoch: 18, global loss: 645.541214287281 epoch: 19, global loss: 633.081243366003 epoch: 20, global loss: 613.7448983192444 epoch: 21, global loss: 604.4050687849522 epoch: 22, global loss: 599.7048689126968 epoch: 23, global loss: 598.0098409950733 epoch: 24, global loss: 593.4990938007832 epoch: 25, global loss: 590.1116417050362 epoch: 26, global loss: 587.6759895980358 epoch: 27, global loss: 585.6909210085869 epoch: 28, global loss: 583.2531938552856 epoch: 29, global loss: 581.097696274519 epoch: 30, global loss: 582.9426203966141 epoch: 31, global loss: 576.5041388273239 epoch: 32, global loss: 576.292847931385 epoch: 33, global loss: 573.6794521808624 epoch: 34, global loss: 572.3775804936886 epoch: 35, global loss: 571.023720651865 epoch: 36, global loss: 569.6730150282383 epoch: 37, global loss: 568.4992664456367 epoch: 38, global loss: 567.3763771653175 epoch: 39, global loss: 566.2756295502186 epoch: 40, global loss: 565.1369759738445 epoch: 41, global loss: 564.0311997830868 epoch: 42, global loss: 562.9334791898727 epoch: 43, global loss: 561.7785176932812 epoch: 44, global loss: 560.6706030964851 epoch: 45, global loss: 559.6120978593826 epoch: 46, global loss: 558.567091614008 epoch: 47, global loss: 557.5228275954723 epoch: 48, global loss: 556.4870015382767 epoch: 49, global loss: 555.5302931666374 epoch: 50, global loss: 554.5933721661568 epoch: 51, global loss: 553.6184344887733 epoch: 52, global loss: 552.800593405962 epoch: 53, global loss: 551.7529977858067 epoch: 54, global loss: 550.8304800391197 epoch: 55, global loss: 549.9011229276657 epoch: 56, global loss: 548.932663500309 epoch: 57, global loss: 547.9840379953384 epoch: 58, global loss: 547.0549567341805 epoch: 59, global loss: 546.1065548658371 epoch: 60, global loss: 545.1039568185806 epoch: 61, global loss: 544.1431631445885 epoch: 62, global loss: 543.1279513835907 epoch: 63, global loss: 541.9938332438469 epoch: 64, global loss: 540.8761181533337 epoch: 65, global loss: 539.7597238123417 epoch: 66, global loss: 539.1635148227215 epoch: 67, global loss: 537.5679467916489 epoch: 68, global loss: 536.4996035695076 epoch: 69, global loss: 535.4936526119709 epoch: 70, global loss: 534.4467121958733 epoch: 71, global loss: 533.3590206205845 epoch: 72, global loss: 532.3147319853306 epoch: 73, global loss: 531.2508494853973 epoch: 74, global loss: 530.1926727890968 epoch: 75, global loss: 529.1155897378922 epoch: 76, global loss: 528.1800818741322 epoch: 77, global loss: 527.2117345333099 epoch: 78, global loss: 526.2947707772255 epoch: 79, global loss: 525.5252906382084 epoch: 80, global loss: 524.7096827626228 epoch: 81, global loss: 523.7984044849873 epoch: 82, global loss: 522.78268840909 epoch: 83, global loss: 521.7484927773476 epoch: 84, global loss: 520.8880970478058 
epoch: 85, global loss: 519.9399253726006 epoch: 86, global loss: 519.0214531719685 epoch: 87, global loss: 518.1659171879292 epoch: 88, global loss: 517.2928357720375 epoch: 89, global loss: 516.4177814126015 epoch: 90, global loss: 515.5256786346436 epoch: 91, global loss: 514.766328305006 epoch: 92, global loss: 514.0889685451984 epoch: 93, global loss: 513.4248222112656 epoch: 94, global loss: 512.8396798968315 epoch: 95, global loss: 512.2728424966335 epoch: 96, global loss: 511.705085337162 epoch: 97, global loss: 511.11995631456375 epoch: 98, global loss: 510.82659527659416 epoch: 99, global loss: 510.0349553525448 epoch: 100, global loss: 509.53576269745827 epoch: 101, global loss: 509.02493235468864 epoch: 102, global loss: 508.5308871269226 epoch: 103, global loss: 508.11364594101906 epoch: 104, global loss: 507.6910216808319 epoch: 105, global loss: 507.3408246040344 epoch: 106, global loss: 507.0248809158802 epoch: 107, global loss: 506.75064131617546 epoch: 108, global loss: 506.49183267354965 epoch: 109, global loss: 506.17414128780365 epoch: 110, global loss: 505.8185847401619 epoch: 111, global loss: 505.43842670321465 epoch: 112, global loss: 505.066110342741 epoch: 113, global loss: 504.76938274502754 epoch: 114, global loss: 504.5238231420517 epoch: 115, global loss: 504.29119166731834 epoch: 116, global loss: 504.07997223734856 epoch: 117, global loss: 503.8681187927723 epoch: 118, global loss: 503.65255653858185 epoch: 119, global loss: 503.4174863100052 epoch: 120, global loss: 503.1654132306576 epoch: 121, global loss: 502.9897632598877 epoch: 122, global loss: 502.7758958339691 epoch: 123, global loss: 502.5622181892395 epoch: 124, global loss: 502.3635467886925 epoch: 125, global loss: 502.1603965163231 epoch: 126, global loss: 501.9763647019863 epoch: 127, global loss: 501.7460001707077 epoch: 128, global loss: 501.4965472817421 epoch: 129, global loss: 501.2906629741192 epoch: 130, global loss: 501.13226383924484 epoch: 131, global loss: 500.93844801187515 epoch: 132, global loss: 500.7943257689476 epoch: 133, global loss: 500.65445777773857 epoch: 134, global loss: 500.47203704714775 epoch: 135, global loss: 500.2710265517235 epoch: 136, global loss: 500.0809588432312 epoch: 137, global loss: 499.96421521902084 epoch: 138, global loss: 499.8568932712078 epoch: 139, global loss: 499.69672188162804 epoch: 140, global loss: 499.54846915602684 epoch: 141, global loss: 499.4482390284538 epoch: 142, global loss: 499.33338513970375 epoch: 143, global loss: 499.22282618284225 epoch: 144, global loss: 499.0941632390022 epoch: 145, global loss: 499.00555112957954 epoch: 146, global loss: 498.9643343985081 epoch: 147, global loss: 498.88275691866875 epoch: 148, global loss: 498.8049059212208 epoch: 149, global loss: 498.7484247982502 epoch: 150, global loss: 498.67636409401894 epoch: 151, global loss: 498.61641108989716 epoch: 152, global loss: 498.4963846206665 epoch: 153, global loss: 498.40352177619934 epoch: 154, global loss: 498.33427825570107 epoch: 155, global loss: 498.2722420990467 epoch: 156, global loss: 498.25645983219147 epoch: 157, global loss: 498.1824731230736 epoch: 158, global loss: 498.1131078004837 epoch: 159, global loss: 498.0111567080021 epoch: 160, global loss: 497.9425509572029 epoch: 161, global loss: 497.9378290474415 epoch: 162, global loss: 497.86262080073357 epoch: 163, global loss: 497.77807688713074 epoch: 164, global loss: 497.65699991583824 epoch: 165, global loss: 497.6025904417038 epoch: 166, global loss: 497.5892086327076 epoch: 167, 
global loss: 497.55601516366005 epoch: 168, global loss: 497.4651515185833 epoch: 169, global loss: 497.3958241343498 epoch: 170, global loss: 497.2622108757496 epoch: 171, global loss: 497.19583734869957 epoch: 172, global loss: 497.10643684864044 epoch: 173, global loss: 497.0346477329731 epoch: 174, global loss: 496.9301113784313 epoch: 175, global loss: 496.8828657269478 epoch: 176, global loss: 496.8037567138672 epoch: 177, global loss: 496.7726205587387 epoch: 178, global loss: 496.7084252536297 epoch: 179, global loss: 496.66760089993477 epoch: 180, global loss: 496.5988882482052 epoch: 181, global loss: 496.57361620664597 epoch: 182, global loss: 496.51167890429497 epoch: 183, global loss: 496.4756168425083 epoch: 184, global loss: 496.4225699901581 epoch: 185, global loss: 496.40601482987404 epoch: 186, global loss: 496.39382633566856 epoch: 187, global loss: 496.4165682196617 epoch: 188, global loss: 496.46642500162125 epoch: 189, global loss: 496.5348289310932 epoch: 190, global loss: 496.4985382556915 epoch: 191, global loss: 496.4195824563503 epoch: 192, global loss: 496.37232729792595 epoch: 193, global loss: 496.27506789565086 epoch: 194, global loss: 496.1768342256546 epoch: 195, global loss: 496.04688143730164 epoch: 196, global loss: 495.95688885450363 epoch: 197, global loss: 495.8400511443615 epoch: 198, global loss: 495.75393357872963 epoch: 199, global loss: 495.6657531261444 epoch: 200, global loss: 495.6118298768997 epoch: 201, global loss: 495.54069915413857 epoch: 202, global loss: 495.5173283815384 epoch: 203, global loss: 495.4639382362366 epoch: 204, global loss: 495.39965626597404 epoch: 205, global loss: 495.34445440769196 epoch: 206, global loss: 495.2881217300892 epoch: 207, global loss: 495.24329379200935 epoch: 208, global loss: 495.1878437399864 epoch: 209, global loss: 495.1401314139366 epoch: 210, global loss: 495.09846091270447 epoch: 211, global loss: 495.04953932762146 epoch: 212, global loss: 495.0189399123192 epoch: 213, global loss: 494.9832808673382 epoch: 214, global loss: 494.96546483039856 epoch: 215, global loss: 494.92814016342163 epoch: 216, global loss: 494.91906571388245 epoch: 217, global loss: 494.93281358480453 epoch: 218, global loss: 494.9305277466774 epoch: 219, global loss: 494.9359532892704 epoch: 220, global loss: 494.9171520769596 epoch: 221, global loss: 494.8724439740181 epoch: 222, global loss: 494.83282539248466 epoch: 223, global loss: 494.785582870245 epoch: 224, global loss: 494.73427817225456 epoch: 225, global loss: 494.6917574107647 epoch: 226, global loss: 494.68557822704315 epoch: 227, global loss: 494.6706165075302 epoch: 228, global loss: 494.63387057185173 epoch: 229, global loss: 494.6083428263664 epoch: 230, global loss: 494.55811724066734 epoch: 231, global loss: 494.53109580278397 epoch: 232, global loss: 494.47960725426674 epoch: 233, global loss: 494.45812487602234 epoch: 234, global loss: 494.4290888309479 epoch: 235, global loss: 494.41093307733536 epoch: 236, global loss: 494.42146465182304 epoch: 237, global loss: 494.4361335039139 epoch: 238, global loss: 494.439922362566 epoch: 239, global loss: 494.4274563193321 epoch: 240, global loss: 494.42593175172806 epoch: 241, global loss: 494.4385249018669 epoch: 242, global loss: 494.4157117009163 epoch: 243, global loss: 494.3583362698555 epoch: 244, global loss: 494.32936242222786 epoch: 245, global loss: 494.30685544013977 epoch: 246, global loss: 494.29458823800087 epoch: 247, global loss: 494.27481147646904 epoch: 248, global loss: 494.2748574912548 
epoch: 249, global loss: 494.27340853214264 epoch: 250, global loss: 494.25723004341125 epoch: 251, global loss: 494.2462451159954 epoch: 252, global loss: 494.22303825616837 epoch: 253, global loss: 494.23058873414993 epoch: 254, global loss: 494.2101854085922 epoch: 255, global loss: 494.2013041675091 epoch: 256, global loss: 494.1888665854931 epoch: 257, global loss: 494.19572001695633 epoch: 258, global loss: 494.2009173333645 epoch: 259, global loss: 494.20802015066147 epoch: 260, global loss: 494.18362244963646 epoch: 261, global loss: 494.16127333045006 epoch: 262, global loss: 494.1818304359913 epoch: 263, global loss: 494.1637741625309 epoch: 264, global loss: 494.1649127304554 epoch: 265, global loss: 494.1519795656204 epoch: 266, global loss: 494.1540858745575 epoch: 267, global loss: 494.1370086967945 epoch: 268, global loss: 494.1586405336857 epoch: 269, global loss: 494.1402297616005 epoch: 270, global loss: 494.1483009457588 epoch: 271, global loss: 494.16792157292366 epoch: 272, global loss: 494.14680847525597 epoch: 273, global loss: 494.1380660235882 epoch: 274, global loss: 494.15593910217285 epoch: 275, global loss: 494.16768592596054 epoch: 276, global loss: 494.21329975128174 epoch: 277, global loss: 494.23703160881996 epoch: 278, global loss: 494.23815140128136 epoch: 279, global loss: 494.18523249030113 epoch: 280, global loss: 494.127505838871 epoch: 281, global loss: 494.07315534353256 epoch: 282, global loss: 494.03050604462624 epoch: 283, global loss: 494.0081321299076 epoch: 284, global loss: 494.0054851770401 epoch: 285, global loss: 494.00201842188835 epoch: 286, global loss: 493.98230481147766 epoch: 287, global loss: 493.9658747911453 epoch: 288, global loss: 493.95692497491837 epoch: 289, global loss: 493.9257173836231 epoch: 290, global loss: 493.9061325788498 epoch: 291, global loss: 493.865954965353 epoch: 292, global loss: 493.8320437371731 epoch: 293, global loss: 493.8257194161415 epoch: 294, global loss: 493.8399261832237 epoch: 295, global loss: 493.82826310396194 epoch: 296, global loss: 493.81327560544014 epoch: 297, global loss: 493.81851002573967 epoch: 298, global loss: 493.82279473543167 epoch: 299, global loss: 493.80466145277023
gi.to("cpu")
att_model.to("cpu")
index = 770
sample = trainset[index][0]
plt.imshow(sample.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
import math
pred = att_model(sample.view(-1,3,32,32))
recon_list = []
classnum_in_gi = 10
loss = nn.CrossEntropyLoss()(pred, torch.LongTensor([7]))
grad = torch.autograd.grad(loss, att_model.parameters(), retain_graph=True)
grad_input = grad[-2][:classnum_in_gi].reshape(classnum_in_gi, 768, 1, 1)
recons, weight = gi(grad_input)
plt.imshow(recons.detach().cpu().numpy().transpose(1,2,0))
plt.axis('off')
plt.show()
plt.close()
weight
tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]], grad_fn=<SoftmaxBackward0>)
grad_input.shape
torch.Size([10, 768, 1, 1])
grad_input.reshape(10, 768)
tensor([[ 3.3652e-04, 3.2886e-04, 3.5680e-04, ..., 3.6786e-04, 3.7276e-04, 3.7532e-04], [-6.6898e-03, -6.5374e-03, -7.0928e-03, ..., -7.3127e-03, -7.4101e-03, -7.4610e-03], [ 2.7396e-06, 2.6772e-06, 2.9047e-06, ..., 2.9947e-06, 3.0346e-06, 3.0555e-06], ..., [ 6.2048e-07, 6.0634e-07, 6.5786e-07, ..., 6.7826e-07, 6.8729e-07, 6.9201e-07], [ 4.0713e-12, 3.9785e-12, 4.3165e-12, ..., 4.4504e-12, 4.5097e-12, 4.5406e-12], [ 1.9508e-03, 1.9063e-03, 2.0683e-03, ..., 2.1324e-03, 2.1608e-03, 2.1757e-03]])
t = grad_input.reshape(10, 768)
torch.sum(t[6])
tensor(0.0002)
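# Side note on the row sums (a standard property of cross-entropy behind a
# linear head, used e.g. by iDLG for label inference): row c of the fc weight
# gradient equals (softmax_c - onehot_c) * h, where h is the 768-dim
# sigmoid-activated feature vector, so only the row belonging to the label used
# in the CE loss has negative entries and a negative sum.
print(t.sum(dim=1))  # the single negative entry marks that label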