日常生活中经常遇到求解数组中是否存在和为0的三个数,即3-sum问题,为此本文介绍一些比较实用的方法,并在暴力计算的基础上对算法逐步改进,以达到最优的算法。
首先介绍一下2-sum问题,3-sum问题其实就是2-sum问题的一个扩展而已。对于2-sum问题,最一般的想法就是使用两层循环直接枚举数组中的所有元素对,直到找到和为0的元素对为止。
int sum_2()
{
int res = 0;
int n = data.size();
for(int i=0; i<n; i++)
{
for(int j=i+1; j<n; j++)
{
if(data[i] + data[j] == 0)
{
res ++;
}
}
}
return res;
}上述算法的由于包含了两层循环,因此时间复杂度为O(N^2)。
观察发现,上述算法时间主要花费在数据比对,为此可以考虑使用二分查找来减少数据比对时间,要想使用二分查找,首先应该对数据进行排序,在此使用归并排序对数组进行升序排列。排序所花时间为O(NlogN),排序之后数据查找只需要O(logN)的时间,但是总共需要查找N此,为此改进后算法的时间复杂度为O(NlogN)。
int cal_sum_2()
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
int j = binary_search(-data[i]);
if(j > i)
res++;
}
return res;
}观察上述算法发现,我们在比对的过程中还是存在了一些冗余。因为排列后的数据是从最小的数开始匹配的,我们只需计算其与最后的数据的和是否为0即可,如果大于0,则说明不存在与最小数匹配的数,此时将用较小的数来替代最大的数,反之则选用较大的数替代最小的数,如此反复,只需要扫描一遍数组即可得到所有符合条件的元素对。此算法所用的时间主要还是数组排序的时间,即O(NlogN)。int cal_sum_2_update()
{
int res = 0;
for(int i=0,j=data.size()-1; i<j; )
{
if(data[i] + data[j] > 0)
j--;
else if(data[i] + data[j] < 0)
i++;
else
{
res++;
j--;
i++;
}
}
return res;
}上述2-sum的解题思路适用于3-sum及4-sum问题,如求解a+b+c=0,可将其转换为求解a+b=-c,此就为2-sum问题。为此将2-sum,3-sum,4-sum的求解方法以及相应的优化方法实现在如下所示的sum类中。
sum类定义
#ifndef SUM_H
#define SUM_H
#include <vector>
using std::vector;
class sum
{
private:
vector<int> data;
public:
sum(){};
sum(const vector<int>& a);
~sum(){};
int cal_sum_2() const;
int cal_sum_3() const;
int cal_sum_4() const;
int cal_sum_2_update() const;
int cal_sum_3_update() const;
int cal_sum_3_update2() const;
int cal_sum_4_update() const;
void sort(int low, int high);
void print() const;
friend int find(const sum& s, int target);
};
#endifsum类实现#include "Sum.h"
#include <iostream>
using namespace std;
sum::sum(const vector<int>& a)
{
data = a;
}
void sum::sort(int low, int high)
{
if(low >= high)
return;
int mid = (low+high)/2;
sort(low,mid);
sort(mid+1,high);
vector<int> temp;
int l = low;
int h = mid+1;
while(l<=mid && h <=high)
{
if(data[l] > data[h])
temp.push_back(data[h++]);
else
temp.push_back(data[l++]);
}
while(l<=mid)
temp.push_back(data[l++]);
while(h<=high)
temp.push_back(data[h++]);
for(int i=low; i<=high; i++)
{
data[i] = temp[i-low];
}
}
void sum::print() const
{
for(int i=0; i<data.size(); i++)
{
cout<<data[i]<<" ";
}
cout<<endl;
}
int find(const sum& s, int target)
{
int low = 0;
int high = s.data.size()-1;
while(low <= high)
{
int mid = (low + high)/2;
if(s.data[mid] < target)
{
low = mid+1;
}
else if(s.data[mid] > target)
{
high = mid - 1;
}
else
{
return mid;
}
}
return -1;
}
int sum::cal_sum_2() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
int j = find(*this, -data[i]);
if(j > i)
res++;
}
return res;
}
int sum::cal_sum_3() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
for(int j=i+1; j<data.size(); j++)
{
for(int p=j+1;p<data.size();p++)
{
if(data[i] + data[j] + data[p] == 0)
res++;
}
}
}
return res;
}
int sum::cal_sum_4() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
for(int j=i+1; j<data.size(); j++)
{
for(int p=j+1; p<data.size(); p++)
{
for(int q=p+1; q<data.size(); q++)
{
if(data[i]+data[j]+data[p]+data[q] == 0)
res++;
}
}
}
}
return res;
}
int sum::cal_sum_2_update() const
{
int res = 0;
for(int i=0,j=data.size()-1; i<j; )
{
if(data[i] + data[j] > 0)
j--;
else if(data[i] + data[j] < 0)
i++;
else
{
res++;
j--;
i++;
}
}
return res;
}
int sum::cal_sum_3_update() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
for(int j=i+1; j<data.size(); j++)
{
if(find(*this, -data[i] - data[j]) > j)
res ++;
}
}
return res;
}
int sum::cal_sum_3_update2() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
int j=i+1;
int p=data.size()-1;
while(j<p)
{
if (data[j] + data[p] < -data[i])
j++;
else if(data[j] + data[p] > -data[i])
p--;
else
{
res++;
j++;
p--;
}
}
}
return res;
}
int sum::cal_sum_4_update() const
{
int res = 0;
for(int i=0; i<data.size(); i++)
{
for(int j=i+1; j<data.size(); j++)
{
for(int p=j+1; p<data.size(); p++)
{
if(find(*this, -data[i]-data[j]-data[p])>p)
res++;
}
}
}
return res;
}测试代码#include "Sum.h"
#include <iostream>
#include <fstream>
#include <vector>
using namespace std;
void main()
{
ifstream in("1Kints.txt");
vector<int> a;
while(!in.eof())
{
int temp;
in>>temp;
a.push_back(temp);
}
sum s(a);
s.sort(0,a.size()-1);
s.print();
cout<<"s.cal_sum_2() = "<<s.cal_sum_2()<<endl;
cout<<"s.cal_sum_2_update() = "<<s.cal_sum_2_update()<<endl;
cout<<"s.cal_sum_3() = "<<s.cal_sum_3()<<endl;
cout<<"s.cal_sum_3_update() = "<<s.cal_sum_3_update()<<endl;
cout<<"s.cal_sum_3_update()2 = "<<s.cal_sum_3_update2()<<endl;
cout<<"s.cal_sum_4() = "<<s.cal_sum_4()<<endl;
cout<<"s.cal_sum_4_update() = "<<s.cal_sum_4_update()<<endl;
}上述算法设计思路希望对你在今后学习算法的过程中有所帮助。
版权声明:本文为shaya118原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。