Algorithm
백준
Python
인구이동

백준 16234 인구 이동

https://www.acmicpc.net/problem/16234 (opens in a new tab)

내 풀이

내가 return [avg, visited]를 해준 이유는 매일마다 연합을 미리 묶은 뒤 연합이 다 묶인 후에 일제히 값을 변환해주기 위해서였다. 그렇게하지 않으면 매일 중복해서 변하는 값이 있을 수 있다고 생각했기 때문이다. 그리고 이게 얼마나 멍청한 생각이었는지는 이 글 (opens in a new tab)(근데 문제는 이 풀이도 틀림)을 보고 알았다. 애초에 매일 visited를 체크해주면 연합으로 묶일 일이 없는데 쓸데없이 시간복잡도와 공간복잡도를 낭비하는 코드를 짜고 있었다.

from collections import deque
import sys
 
input = sys.stdin.readline
 
N,L,R = map(int, input().split())
 
graph = []
 
for _ in range(N):
  tmp = list(map(int,input().split()))
  graph.append(tmp)
 
def localized(start) :
  global graph
 
  q = deque()
  visited = []
  local_values = []
  q.append(start)
  visited.append(start)
  _d = [[-1,0], [1,0], [0,-1], [0, 1]]
 
  while q :
    ly, lx = q.popleft()
    local_values.append(graph[ly][lx])
    for d in _d :
      cy, cx = ly - d[0], lx - d[1]
      if -1 < cy < N and -1 < cx < N :
        if L <= abs(graph[ly][lx] - graph[cy][cx]) <= R :
          if [cy, cx] not in visited :
            visited.append([cy, cx])
            q.append([cy, cx])
  avg = sum(local_values) // len(local_values)
  if len(local_values) > 1 :
    return [avg, visited]
  return None
 
cnt = 0
keep_loop = True
while keep_loop :
  visited = []
  results = []
  keep_loop = False
  for y in range(N) :
    for x in range(N) :
      if [y,x] not in visited :
        result = localized([y,x])
        if result :
          keep_loop = True
          results.append(result)
          for v in result[1] :
            visited.append(v)
  for r in results :
    avg, visited = r
    for v in visited :
      graph[v[0]][v[1]] = avg
  if keep_loop :
    cnt += 1
print(cnt)

visited를 이용해서 연합 중복 방지

아래 코드를 보면 매일 visited를 만들고 연합으로 묶인 것들을 visited에 추가하고 연합으로 묶인 것들은 그 때 그 때 바로 처리해줌으로써 중복되는 연합을 피해주고 있다.

import collections
import sys
input = sys.stdin.readline
 
n, l, r = list(map(int, input().split()))
graph = [list(map(int, input().split())) for _ in range(n)]
dx = [0, 0, 1, -1]
dy = [1, -1, 0, 0]
 
cnt = 0
while True:
    visited = [[0] * n for _ in range(n)]
    flag = True
 
    # 인구 이동 한 세트
    for m in range(n):
        for v in range(n):
            if visited[m][v] == 0:
                queue = collections.deque([(m, v)])
                visited[m][v] = 1
 
                # (m, v) 기준으로 연합 구하기
                temp = [(m, v)]
                while queue:
                    i, j = queue.pop()
                    for z in range(4):
                        nx = i + dx[z]
                        ny = j + dy[z]
                        if 0 <= nx < n and 0 <= ny < n and visited[nx][ny] == 0:
                            if l <= abs(graph[i][j] - graph[nx][ny]) <= r:
                                queue.appendleft((nx, ny))
                                visited[nx][ny] = 1
                                temp.append((nx, ny))
 
                # 연합에 대해서 그래프 값 갱신하기
                if len(temp) > 1:
                    flag = False
                    meanv = sum([graph[i][j] for i, j in temp]) // len(temp)
                    for i, j in temp:
                        graph[i][j] = meanv
 
    if flag:  # 더 이상 인구 이동이 없으면
        break
    cnt += 1
print(cnt)
 

정답 풀이

import sys
from collections import deque
N, L, R = map(int,input().split())
country = []
dx = [1,-1,0,0]
dy = [0,0,1,-1]
population_flag = True
result = 0
 
for _ in range(N):
    country.append([int(x) for x in sys.stdin.readline().rstrip().split()])
 
def bfs(x,y):
    global population_flag
    union_list = []
    union_person = 0
    deq = deque()
 
    visited[x][y] = True
    union_list.append([x,y])
    union_person += country[x][y]
    deq.append([x,y])
 
    while deq:
        a,b = deq.popleft()
        for i in range(4):
            nx, ny = a + dx[i], b + dy[i]
            if 0 <= nx < N and 0 <= ny < N and not visited[nx][ny]:
                if L <= abs(country[a][b] - country[nx][ny]) <= R:
                    visited[nx][ny] = True
                    deq.append([nx,ny])
                    union_list.append([nx,ny])
                    union_person += country[nx][ny]
 
    union_len = len(union_list)
 
    if union_len >= 2:
        for u in union_list:
            country[u[0]][u[1]] = union_person // union_len
        population_flag = True
 
def check(x, y):
    for i in range(4):
        nx, ny = x + dx[i], y + dy[i]
        if 0 <= nx < N and 0 <= ny < N:
            if L <= abs(country[x][y] - country[nx][ny]) <= R:
                return 1
    return 0
 
while population_flag:
    population_flag = False
    visited = [[False] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            if not visited[i][j] and check(i,j):
                bfs(i,j)
 
    if population_flag:
        result += 1
 
print(result)

Reference