논문 링크.
Explaining and Harnessing Adversarial Examples논문에 등장하는 FGSM은 적대적예시 생성 공격 기법의 시초격이기에 매우 유명하고 타 블로그에서 많이다루기도해서 짧게만 설명하고 코드로 들어가보려고합니다.

FGSM

FGSM은 Explaining and Harnessing Adversarial Examples에서 등장한 공격기법이며 기존 모델을 바탕으로 이미지에 noise를 가해서 적대적예시(Adversarial examples)을 만드는 것을 말합니다.

적대적 예시(Adversarial Example)가 무엇이죠?
원래 잘분류되던 이미지를 의도적으로 조작하여 오분류하도록 만든 이미지라고 할 수 있습니다.

noise는 어떻게 추가하나요?
FGSM은 Fast Gradient Sign Method의 약자로 이걸 풀어보자면 빠르고 효율적이면서 기울기의 방향을 고려한 방법이라고 할 수 있을 것 같은데요.

이를 단계적으로 풀어보자면 아래와 같습니다.

  1. 그래디언트 계산: 먼저, 입력 데이터에 대해 모델의 손실 함수의 그래디언트(기울기)를 계산합니다. 이 그래디언트는 모델의 예측에 대해 입력 데이터를 어떻게 변화시키면 손실이 증가하는지를 나타냅니다.

  2. 그래디언트의 부호 취하기: 이 그래디언트의 부호(sign)만을 취합니다. 즉, 각 그래디언트 요소는 그 값이 양수이면 +1, 음수이면 -1로 변환됩니다. 이렇게 하면 방향성만을 남기고 그 크기는 무시하게 됩니다.
    (참고)경사하강법에서 최적화 할때에는 그래디언트 요소가 양수일 때 learning late만큼 음수로 변화시켜서 최소화시켰습니다. FGSM 이 반대 방향으로 부호를 취한다고 할 수 있습니다

  3. 노이즈 생성 및 추가: 이 부호화된 그래디언트에 작은 계수(ε, 엡실론)를 곱하여 노이즈를 생성합니다. 이후 노이즈를 원래의 입력 데이터에 추가합니다. 이렇게 하면 원래 데이터는 조금씩 변형되어 모델의 예측을 잘못하게 만드는 새로운 데이터(악의적인 예시)가 생성됩니다.

적대적예시를 만드는 방법은 알겠는데 타깃은 어떻게 지정하나요?
모델이 특정 타깃 클래스에 대한 예측으로 손실을 가장 많이 받도록 그래디언트를 계산하고 조정하는 식으로 타깃을 지정할 수 있습니다.

CIFAR-10, ResNet-18, FSGM을 이용해서 타겟이 지정된 Adversarial Example 생성해보기

적대적예시를 생성하기 이전에 먼저 모델을 학습했습니다.
Accuracy: 89.78%

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# CIFAR-10 데이터셋에 대한 전처리
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# CIFAR-10 트레이닝 및 테스트 데이터셋 로드
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# 사전 학습된 ResNet-18 모델 로드 및 수정
model = resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # CIFAR-10의 클래스 수에 맞게 조정

# GPU 사용 설정 (가능한 경우)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 모델 학습을 위한 손실 함수와 옵티마이저 설정
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 모델 학습 함수
def train_model(model, criterion, optimizer, dataloader, num_epochs=5):
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader):.4f}')

    print('Finished Training')

# 모델 학습
train_model(model, criterion, optimizer, trainloader)

# 정확도를 계산하는 함수
def calculate_accuracy(model, dataloader):
    correct = 0
    total = 0
    model.eval()  # 평가 모드로 설정
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

# 테스트 데이터셋에 대한 정확도 계산
accuracy = calculate_accuracy(model, testloader)
print(f'Accuracy: {accuracy}%')

적대적 예시 생성

# CIFAR-10 클래스 레이블
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 원하는 타겟 레이블 설정
target_label = 1

# 타겟 공격 함수 정의
def targeted_fgsm_attack(image, epsilon, data_grad, target):
    sign_data_grad = data_grad.sign()
    perturbed_image = image - epsilon * sign_data_grad  # 타겟 공격을 위해 빼기 연산 사용
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

# 타겟 공격 성공 이미지 추출 함수
def extract_targeted_attacks(model, dataloader, epsilon, target_label):
    targeted_attacks = []

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        for idx in range(len(labels)):
            label = labels[idx]

            # 레이블이 타겟 레이블과 일치하면 건너뜁니다.
            if label == target_label:
                continue

            # 각 이미지-레이블 쌍을 개별적으로 처리합니다.
            image = images[idx].unsqueeze(0)
            image.requires_grad = True  # 개별 이미지에 대해 gradient 계산 활성화

            # 모델의 원래 예측 확인
            original_output = model(image)
            _, original_pred = original_output.max(1)

            # 원래 모델이 이미지를 올바르게 분류했는지 확인
            if original_pred.item() != label.item():
                continue

            # 타겟 레이블로 설정된 레이블을 사용하여 손실 계산
            target = torch.tensor([target_label], device=device)
            outputs = model(image)

            loss = F.cross_entropy(outputs, target)
            model.zero_grad()
            loss.backward()
            data_grad = image.grad.data

            perturbed_data = targeted_fgsm_attack(image, epsilon, data_grad, target)
            output = model(perturbed_data)
            _, final_pred = output.max(1)

            # 타겟 레이블로 분류되도록 예제 생성
            if final_pred.item() == target_label:
                noise = epsilon * data_grad.sign().squeeze().detach().cpu().numpy()
                targeted_attacks.append((perturbed_data.squeeze().detach().cpu().numpy(), noise, label.item(), final_pred.item()))
                if len(targeted_attacks) >= 10:
                    break

    return targeted_attacks


epsilon = 0.01  # FGSM 공격에 사용할 epsilon 값
successful_targeted_attacks = extract_targeted_attacks(model, testloader, epsilon, target_label)

def plot_targeted_attacks(targeted_attacks):
    if not targeted_attacks:
        print("No successful targeted adversarial attacks to display.")
        return

    num_attacks = len(targeted_attacks)
    fig, axs = plt.subplots(2, num_attacks, figsize=(2 * num_attacks, 4))

    if num_attacks == 1:
        axs = [axs]

    for i, (perturbed, noise, original_label, pred_label) in enumerate(targeted_attacks):
        ax_img = axs[0][i]
        ax_noise = axs[1][i]

        # 데이터 범위를 [0, 1]로 클리핑
        perturbed_clipped = np.clip(np.transpose(perturbed, (1, 2, 0)), 0, 1)

        # 노이즈 정규화
        noise_normalized = (noise - noise.min()) / (noise.max() - noise.min())

        # 이미지 크기를 축소해서 표시
        ax_img.imshow(perturbed_clipped)
        pred_label_str = classes[pred_label] if 0 <= pred_label < len(classes) else "Unknown"
        original_label_str = classes[original_label] if 0 <= original_label < len(classes) else "Unknown"
        ax_img.set_title(f'Original: {original_label_str}\nPred: {pred_label_str}')
        ax_img.axis('off')

        # 이미지 크기를 축소해서 표시
        ax_noise.imshow(np.transpose(noise_normalized, (1, 2, 0)))
        ax_noise.set_title('Noise (Normalized)')
        ax_noise.axis('off')

    plt.tight_layout()
    plt.show()

plot_targeted_attacks(successful_targeted_attacks)

출력 결과

fgsm