세그먼트 트리에 대한 글은 이미 많습니다.
심지어 다들 구체적으로 설명해주고 계십니다.
그런데 세그먼트 트리를 단지 구간합 트리라고 알고 계시는 분들이 생각보다 많이 계신 이유로 오해를 정정하고자 포스팅해보겠습니다.
세그먼트 트리가 없다면?
세그먼트 트리(구간 트리)는 주어진 쿼리에 대해서 빠르게 응답하기 위한 자료구조입니다.
배열 A가 주어져있고 A의 start~end 구간까지의 합을 구하려고 합니다.
for문으로 answer+=A[i]를 돌리면 되겠죠.
그런데 만약 이 구간내 값이 변동이 된다면 어떨까요??
총 M번 반복한다고 가정했을때,
수정 연산(O(1))+ 합산 연산 = O(NM)의 시간복잡도가 발생합니다.
세그먼트 트리 도입
여기에 세그먼트 트리를 사용한다면 어떻게 변할까요??
수정 연산 = 합산 연산 = O(log N)
으로 O(MlogN)이 됩니다.
수정 연산에서 시간 복잡도가 조금 상승했지만~ 총합에서는 꽤나 괜찮죠.
세그먼트 트리는 구간 트리로 각 노드에 저장되는 값이 조금 특이해요.
각 구간별 특정 연산에 대한 값을 저장하게 됩니다.
예제에 같은 합산 연산인 경우 맨 밑. 리프 노드는 그냥 자기 자신의 값을 가지고 부모 노드는 left+right의 값을 가집니다.
보통은 이제 이렇게 구간합 예시를 들고 세그먼트 트리 구현을 위한 설명을 해주시죠.
딱 거기까지만 보셔서 그런지 오해하시는 분들이 계셨습니다.
저 A+B는 min(A,B)로 대체될 수 있고, max도 가능하죠.
중요한건 구간'합'이 아니라 '구간' 트리라는 부분입니다.
소스코드
자.. 세그먼트 트리 구현은 음.. 두 가지로 나뉩니다.
bottom up이냐 top down이냐에 따라서 바뀌는데 bottom up이 접근성과 성능이 좋다고는 해요.
대신 top down 만큼의 확장성은 덜하다고 많이들 말씀하시네요.
+ top down은 일반적으로 재귀를 사용하므로 성능이 떨어진다고 합니다.
제 개인적으로는 top down 방식이 익숙하니 이 방식대로 해볼게요.
초기화
직접 클래스로 작성하셔도 되지만 간단하게 배열로 작성해도 됩니다.(어떤 노드의 번호가 x일때, left 는 2*x, right는 2*x+1이 되므로 인덱싱도 쉬워서 배열도 가능합니다.)
세그먼트 트리를 배열로 구현하기 전에 사이즈를 미리 정해놓는 편이 좋죠.
효율적으로 결정하려면 Crocus 님께서 언급해주신 내용을 그대로 따라해야합니다.
N = 12일 때의 세그먼트 트리의 전체 크기(배열 사이즈 정하기)를 구하기 위해서는
2^k로 12보다 바로 큰 값을 만들 수 있는 k를 찾아야한다. 즉, k는 4이다.
그리고 난 뒤 2^k를 하면 16이 되고 16에 *2를 하면 우리가 원하는 세그먼트 트리의 크기를 구할 수 있다.
int h = (int)ceil(log2(n));
int tree_size = (1 << (h+1));
출처: https://www.crocus.co.kr/648 [Crocus]
이 과정이 번거롭다면 그냥 4*N으로 트리 사이즈를 정해주세요.
def init(start, end, index):
if start==end:
tree[index] = array[start]
return tree[index]
mid = (start+end)//2
tree[index] = init(start,mid,index*2)+init(mid+1,end,index*2+1)
return tree[index]
세그먼트 트리를 맨 위 루트 노드부터 left,right를 분리시켜주는 과정입니다.
언제까지? 리프 노드일때까지 이 과정을 반복해주세요.
즉 루트 노드가 array의 0~9번 인덱스까지의 내용을 가지고 있으면
0~4까지는 left로 5~9까지는 right의 구간으로 정합니다.
세그먼트 트리(구간합ver)의 뼈대는 완성됐습니다.
Update
주어진 배열에(세그먼트 트리 아님) 값의 변화가 일어났다면 트리에 반영을 해줘야죠.
def update(start, end, index,k,diff):#k는 바꾸고자 하는 인덱스 번호,diff는 바꾸는 값
if (not(start<=k<=end)):
return
tree[index] += diff
if start!=end:
mid = (start+end)//2
update(start,mid,index*2,k,diff)
update(mid+1,end, index * 2+1, k, diff)
top-down은 이 부분 때문에 bottom-up 방식보다 번거롭다고 하나봐요.
일단 범위를 넘어가면 당연히 함수를 실행할 필요가 없죠.
다음에는 세그 트리에 값을 반영해줘야 합니다.
root 노드부터 시작해서 k인덱스랑 관련된 모든 값들은 전부 계산을 다시 해줘야죠.
이 파트에서 재귀가 또 이용됩니다.
구간합 기준으로 diff변수에 대한 언급이 필요하겠네요.
k인덱스의 값 5가 15로 바꼈다고 해봅시다.
그러면 차이는 10이고 딱 10만큼만 추가로 더해주면 갱신이 끝납니다.
수가 더 작아졌어도 반영이 되겠네요.
query
def query(start,end,index,left,right):
#범위를 벗어나는 경우
if left>end or right<start:
return 0
#범위 내에 있는 경우
if left <= start and end<=right:
return tree[index]
#범위 재탐색
mid = (start+end)//2
return query(start,mid,index*2,left,right)+query(mid+1,end,index*2+1,left,right)
다른 예제에서도 활용하려고 메서드 이름을 query로 해뒀는데 지금은 그냥 summation 하는 과정입니다.
이 부분은 사실 그림을 그려보고 이해해보는 과정이 필요합니다.
제가 직접 다루기에는.. 그림이...ㅠㅠㅠ
그래서 난 감이 안온다 하시는 분들은
이 글을 참고해주세요.
bottom-up 방식
혹시 몰라서 bottom-up 방식도 추가하겠습니다.
n = len(arr)
def init(arr) :
# 리프 노드를 트리에 삽입
for i in range(n) :
tree[n + i] = arr[i]
# 부모 노드 계산(리프 -> 부모 ->루트까지 거꾸로 올라감)
for i in range(n - 1, 0, -1):
tree[i] = tree[2*i] + tree[2*i+1]
'''
같은 내용
for i in range(n - 1, 0, -1) :
tree[i] = tree[i << 1] + tree[i << 1 | 1]
'''
def update(k, value) :
# k인덱스에 value 삽입
tree[k + n] = value
# 부모 노드에 반영
i = k + n
while i > 1:
tree[i//2] = tree[i] + tree[i+1]
i =i//2
'''
while i > 1 :
tree[i >> 1] = tree[i] + tree[i ^ 1]
i >>= 1
'''
def query(l, r) :
answer = 0
l += n
r += n
while l < r:
if ((l & 1)>0):#
answer += tree[l]
l += 1
if ((r & 1)>0):
r -= 1
answer += tree[r]
l =l// 2
r =r// 2
'''
while l < r :
if (l & 1) :
answer += tree[l]
l += 1
if (r & 1) :
r -= 1
answer += tree[r]
l >>= 1
r >>= 1
'''
return answer
코드 묶음이 주석처리가 된 것은 비트 연산을 이용해서 조금 더 효율적으로 세그먼트 트리를 작성하는 방식이라고 해서 가져왔습니다.
const int N = 1e5; // limit for array size
int n; // array size
int t[2 * N];
void build() { // build the tree
for (int i = n - 1; i > 0; --i) t[i] = t[i<<1] + t[i<<1|1];
}
void modify(int p, int value) { // set value at position p
for (t[p += n] = value; p > 1; p >>= 1) t[p>>1] = t[p] + t[p^1];
}
int query(int l, int r) { // sum on interval [l, r)
int res = 0;
for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
if (l&1) res += t[l++];
if (r&1) res += t[--r];
}
return res;
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%d", t + n + i);
build();
modify(0, 1);
printf("%d\n", query(3, 11));
return 0;
}
이 C/C++ 소스 코드의 출처는 코드포스로 저렇게 라인수도 줄이고 효율도 업시킬 수 있습니다.
이 글이 다른 내용들도 다루고 있어서 대회준비중이신 분들은 보셔도 좋을듯합니다.
codeforces.com/blog/entry/18051
결국 저도 구간합 위주로만 설명을 드렸는데요.
아까 언급했듯이 '구간'트리이므로 다른 연산들로 대체 가능합니다.
합대신에 곱하기를 넣으면 구간곱 트리가 되고 틀 하나로 여러 문제에 가져다가 쓸 수 있어요.
제가 전하려고 했던 포인트가 제대로 전달이 되었으면 좋겠네요. 감사합니다.
참고 및 추천 링크(본문에 삽입된 링크 제외)
각 연산 단위로 시간 복잡도 언급하십니다.
sungwookyoo.github.io/algorithms/SegmentTree/
세그먼트 트리 다음으로 익혀두면 좋은 펜윅 트리랑 lazy propagation
'알고리즘 > 이론' 카테고리의 다른 글
[JAVA] 이진트리 탐색 dfs(중위 탐색)와 bfs 구현하기 (0) | 2021.04.20 |
---|---|
크루스칼 알고리즘 (0) | 2019.05.26 |
그래프 표현 방식과 탐색 방법 (0) | 2019.05.19 |
Dynamic Programming (0) | 2019.05.04 |
Hash 함수 (2) | 2019.04.30 |