你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

LeetCode刷题笔记 排序算法 归并排序

2021/12/28 11:45:07

归并排序简介

算法思想:

​ 归并排序的核心思想是采用分治策略,将整个数组的排序任务分类为两个子问题,前一半排序和后一半排序,然后整合两个有序部分完成整体排序。即把数组分为若干个子序列,直到单个元素组成一个序列,然后将各阶段得到的序列组合在一起得到最终完整排序序列。

​ 归并排序任务可以如下分治完成:

1. 把前一半排序
2. 把后一半排序
3. 把两半归并到一个新的有序数组,然后再拷贝回原数组,排序完成。

执行样例:

输入:[29,10,14,37,14,25,10]

在这里插入图片描述

算法实现:

// 将数组 a 的局部 a[s,m] 和 a[m+1,e] 合并到 tmp, 并保证 tmp 有序,然后再拷贝回 a[s,m]
void merge(vector<int>& arr, int start, int mid, int end, vector<int> tmp){
    int pTmp = 0;
    int pLeft = start; int pRight = mid+1;
    while(pLeft<=mid&&pRight<=end){
        if(arr[pLeft] < arr[pRight]){
            tmp[pTmp++] = arr[pLeft++];
        }else{
            tmp[pTmp++] = arr[pRight++];
        }
    }
    while(pLeft<=mid){
        tmp[pTmp++] = arr[pLeft++];
    }
    while (pRight<=end)
    {
        tmp[pTmp++] = arr[pRight++];
    }
    for(int i=0;i<pTmp;i++){
        arr[start+i] = tmp[i];
    }
}

// 归并排序递归调用,先排前半部分,在排后半部分,最后将两部分结果合并
void mergeSort(vector<int>& arr, int start, int end, vector<int> tmp){
    if(start < end){
        int mid = start + (end-start)/2;
        mergeSort(arr,start,mid,tmp);
        mergeSort(arr,mid+1,end,tmp);
        merge(arr,start,mid,end,tmp);
    }
}

493 翻转对

给定一个数组 nums ,如果 i < jnums[i] > 2*nums[j] 就将 (i,j) 称作一个翻转对,返回给定数组中翻转对的数量。

输入一个数组,输出一个整数表示数组中翻转对的个数

输入: [2,4,3,5,1]
输出: 3
解释: (4,1),(3,1),(5,1)三个翻转数对

解析:

​ 本题的实质是求数组中的逆序数,但是本题中对逆序数有新的定义。

​ 我们基于归并排序算法采用分治策略:将排列分为两部分,先求左半部分的翻转对,然后求右半部分的翻转对;最后算两部分之间存在的翻转对,时间复杂度 O(nlogn)

​ 左半边和右半边都是排好序的。例如,都是从大到小排序的,左右半边只需要从头到尾各扫一遍,就可以找出由两边各取一个数构成的翻转对数。

​ 下面给出了一种较为容易理解的实现方法,在归并排序代码的基础上做了很小的修改,就是当左侧元素大于右侧时开始寻找翻转对。但是,本题的数据规模最大可达到50000,如果使用这种简单循环遍历将导致超时。

class Solution {
public:
    void mergeAndCount(vector<int>& nums, int start, int mid, int end, vector<int> tmp, int& count){
        if(start>end) return;
        int pTmp = 0;
        int pLeft = start, pRight = mid+1;
        while(pLeft<=mid && pRight<=end){
            if(nums[pLeft] < nums[pRight]){
                tmp[pTmp++] = nums[pLeft++];
            }else{
                // 这种循环遍历方法将导致超时
                for(int i=pLeft;i<=mid;++i){
                    if(nums[i] > (long long)2*nums[pRight]){
                        ++count;
                    }
                }
                tmp[pTmp++] = nums[pRight++];
            }
        }
        while(pLeft<=mid){
            tmp[pTmp++] = nums[pLeft++];
        }
        while(pRight<=end){
            tmp[pTmp++] = nums[pRight++];
        }
        for(int i=0;i<pTmp;++i){
            nums[start+i] = tmp[i];
        }
    }

    void mergeSort(vector<int>& nums, int start, int end, vector<int> tmp, int& count){
        if(start<end){
            int mid = start + (end-start)/2;
            mergeSort(nums,start,mid,tmp,count);
            mergeSort(nums,mid+1,end,tmp,count);
            mergeAndCount(nums,start,mid,end,tmp,count);
        }
    }

    int reversePairs(vector<int>& nums) {
        int ans = 0;
        vector<int> tmp(nums.size());
        mergeSort(nums,0,nums.size()-1,tmp,ans);
        return ans;
    }
};

​ 为了降低时间复杂度,我们采用如下策略:如果左半边 A1 和右半边 A2 都是排好序的,我们就可以在线性时间内解决这个问题了。当然,也可以用二分查找来解决,但是时间复杂度就是线性对数的了。

  • 初始化两个指针i,j分别指向A1,A2的头部

  • 如果 A1[i] > 2*A2[j] ,那么A1[i]及其后面的所有元素都符合要求,更新答案并后移j,否则,后移i

  • 最后,合并A1, A2 以备解决后面更大的子问题使用,并返回前结果

class Solution {
public:
    int find_reversed_pairs(vector<int>& nums, int start, int end){
        int res = 0,mid = start + (end-start)/2;
        int i = start,j = mid+1;
        for(;i <= mid;i++){
            while(j <= end && (long)nums[i] > 2*(long)nums[j]) {
                res += mid - i + 1;
                j++;
            }
        }
        return res;
    }
    
    int merge_sort(vector<int>& nums, int start, int end, vector<int> tmp){
        if(start >= end) return 0;
        int mid = start + (end-start) / 2;
        
        int res = merge_sort(nums,tmp,start,mid) + 
                  merge_sort(nums,tmp,mid+1,end) + 
                  find_reversed_pairs(nums,start,end);
        
        int i = start,j = mid+1,ind = start;
        
        while(i <= mid && j <= end){
            if(nums[i] <= nums[j]) tmp[ind++] = nums[i++];
            else tmp[ind++] = nums[j++];
        }
        while(i <= mid) tmp[ind++] = nums[i++];
        while(j <= end) tmp[ind++] = nums[j++];
        
        for(int ind = start;ind <= end;ind++) nums[ind] = tmp[ind];
    
        return res;
    }
    
    int reversePairs(vector<int>& nums) {
        if(nums.empty()) return 0;
        vector<int>& tmp(nums.size());
        return merge_sort(nums,0,nums.size()-1,tmp);
    }
};

148 排序链表

给定链表的头结点 head ,请将其按 升序 排列并返回 排序后的链表 ,要求在 O(nlogn)的时间复杂度内完成。

输入一个链表,输出一个按升序排列的链表。

输入:head = [4,2,1,3]
输出:[1,2,3,4]

解析:

​ 在完成本题之前可以先尝试 21 合并两个有序链表、23 合并K个升序链表 两题,也许对你解决本题有帮助。

​ 本题我们可以使用归并排序的思想完成。采用分治策略,先将链表划分为两个部分,然后在每个部分中进行排序。所以,我们需要完成如下两个任务:

  • 寻找链表中点划分链表
  • 排序局部链表,并合并结果

​ 第一个任务我们可以使用快慢指针完成,快指针一次走两步,慢指针一次走一步,这样当快指针到达链表末尾时,慢指针刚好指向链表中点。然后,我们递归地采用上诉方法将链表不断划分为更小的子问题,直到不可再分。这里需要注意一个常犯的错误:在寻找中点时不能直接以 nullptr 判断快指针是否达到终点,而是要将其与链表尾节点 tail 进行比较。

​ 第二个任务升序合并链表:分割完成之后,迭代合并两部分链表,将链表按照升序合并,最终形成完整的升序链表。

class Solution {
public:
	// 合并两个链表
    ListNode* merge(ListNode *l1, ListNode *l2){
        if(!l1 || !l2) return l1?l1:l2;
        ListNode dummy;
        ListNode *tail = &dummy;
        ListNode *p1 = l1, *p2 = l2;
        while(p1 && p2){
            if(p1->val < p2->val){
                tail->next = p1;
                p1 = p1->next; 
            }else{
                tail->next = p2;
                p2 = p2->next;
            }
            tail = tail->next;
        }
        tail->next = p1?p1:p2;
        tail = nullptr;
        return dummy.next;
    }
	
    // 寻找链表终点并递归划分链表
    ListNode* mergeSort(ListNode* head, ListNode* tail){
        if(!head) return nullptr;
        if(head->next == tail){
            head->next = nullptr;
            return head;
        }

        ListNode *slow = head, *fast = head;
        // // 错误示范,会导致越界
        // while(fast&&fast->next){
        //     fast = fast->next->next;
        //     slow = slow->next;
        // }
        while(fast != tail){
            slow = slow->next;
            fast = fast->next;
            if(fast != tail){
                fast = fast->next;
            }
        }
        ListNode* mid = slow;
        return merge(mergeSort(head,mid),mergeSort(mid,tail));
    }

    ListNode* sortList(ListNode* head) {
        return mergeSort(head,nullptr);
    }
};

参考资料

LeetCode 101:和你一起轻松刷题(C++) 第 5 章 千奇百怪的排序算法

算法分析与设计 八大排序算法