컴퓨터공학/알고리즘2

[알고리즘2] Percolation with Union Find(WQU) - 실습

NIMHO 2022. 10. 11. 23:15
728x90

복습하기 위해 학부 수업 내용을 필기한 내용입니다.
이해를 제대로 하지 못하고 정리한 경우 틀린 내용이 있을 수 있습니다.
그러한 부분에 대해서는 알려주시면 정말 감사하겠습니다.

Percolate

  • N × N 개의 객체가 격자를 이룬다 (그림 참조)
  • 각 객체는 두 상태(열림, 닫힘) 중 하나를 가질 수 있으며
  • 가장 윗줄이 가장 아랫줄에 연결되었다면 (열린 격자 통해 이동 가능) 이 격자는 percolate 한다고 한다.

 

simulation 방법 개요

  • N × N 개의 객체를 닫힌 상태로 초기화
  • 닫힌 객체 중 하나를 (임의로 선정해) 열린 상태로 바꾸고 percolate 하는지 확인
  • 위 -> 아래로 percolate 할 때까지 반복
  • percolate 할 때 열려 있는 객체의 비율(=열린 객체 수 / (N × N))을 p의 예측치로 사용
  • 위와 같은 시뮬레이션을 여러 회 반복해 p의 평균 혹은 신뢰 구간 구하기
728x90

simulation 방법

  • N을 입력으로 받고
  • N × N 개의 객체를 닫힌 상태로 초기화
  • 닫힌 객체 중 하나를 (임의로 선정해) 열린 상태로 바꾸고, 위 -> 아래로 percolate 할 때까지 반복
    • 인접한 열린 객체는 서로 연결(union)되어 같은 connected component에 속한다고 보면 됨
    • 한 객체를 열 때마다 인접한 4곳(up, down, left, right) 열렸는지 확인해 열린 객체와 모두 연결

 

프로그램 입출력 조건

  • 정수 n과 t를 입력으로 받는 함수 정의 (1≤n≤200, 2≤t≤105 )
    • def simulate(n, t):
  • 위 함수는 n × n 격자에 대해 t회 시뮬레이션 반복한 후
  • Percolate 할 때 열린 객체 비율의 평균, 표준편차, 95% 신뢰구간을 아래 예제와 같은 형식으로 출력
    • t회 예측치가 x1, x2, …, xt라 할 때,
    • 평균 = (x1+x2+…+xt) / t     # statistics.mean() 사용해 계산
    • 표준편차^2 = {(x1-평균)^2 + (x2-평균)^2 + … (xT-평균)^2} / (t-1)     # statistics.stdev() 사용해 계산
    • 95% 신뢰구간 = [평균 – 1.96 * 표준편차 / √t , 평균 + 1.96 * 표준편차 / √t ]     # math.sqrt() 사용해 제곱근 계산
    • 위 값은 모두 소수점 아래 10자리로 출력 (format string “.10f”에 해당)
  • 이 중 평균과 표준편차를 반환
  • Weighted Quick Union 방법 사용해 구현해야 하며
  • 성능: 이어지는 각 예제에 대해 10초 이내에 출력이 나오면 됨
import random
import statistics
import math


def simulate(N, T):
    def root(i):
        while i != ids[i]:
            i = ids[i]
        return i

    def connected(p, q):
        return root(p) == root(q)

    def union(p, q):
        id1, id2 = root(p), root(q)
        if id1 == id2:
            return
        if size[id1] <= size[id2]:
            ids[id1] = id2
            size[id2] += size[id1]
        else:
            ids[id2] = id1
            size[id1] += size[id2]

    number = []
    first = N * N
    last = N * N + 1
    for c in range(T):
        cnt = 0
        idx = []
        ids, size, oc = [], [], []
        for i in range(N * N):
            idx.append(i)
        random.shuffle(idx)

        for i in range(N * N + 2):
            ids.append(i)
            size.append(1)
            oc.append(0)

        while True:
            a = idx.pop()
            if a % N != 0 and oc[a - 1] == 1:
                union(a, a - 1)
            if (a + 1) % N != 0 and oc[a + 1] == 1:
                union(a, a + 1)
            if a - N >= 0 and oc[a - N] == 1:
                union(a, a - N)
            if a + N < N * N and oc[a + N] == 1:
                union(a, a + N)
            oc[a] = 1
            cnt += 1
            if 0 <= a < N:
                union(first, a)
            if N * N - N <= a < N * N:
                union(last, a)

            if connected(first, last):
                number.append(cnt / (N * N))
                break
    mean = statistics.mean(number)
    stdev = statistics.stdev(number)
    a = mean - 1.96 * stdev / math.sqrt(T)
    b = mean + 1.96 * stdev / math.sqrt(T)
    print("mean                    = {:.10f}".format(mean))
    print("stdev                   = {:.10f}".format(stdev))
    print("95% confidence interval = [{:.10f}, {:.10f}]".format(a, b))
    return mean, stdev
728x90