| input | output |
|---|---|
8 8 7 1 6 4 3 5 2 |
3 5 1 4 5 8 7 6 3 2 |
6 3 5 1 2 6 4 |
3 3 3 5 6 1 2 4 |
6 3 5 2 1 6 4 |
Fail |
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <vector>
#include <math.h>
#include <algorithm>
using namespace std;
#define ls 2*i
#define rs 2*i+1
#define up(i,x,y) for(i=x;i<=y;i++)
#define down(i,x,y) for(i=x;i>=y;i--)
#define mem(a,x) memset(a,x,sizeof(a))
#define w(a) while(a)
#define LL long long
const double pi = acos(-1.0);
#define Len 100005
#define mod 1000000007
const int INF = 0x3f3f3f3f;
//0为递增,1为递减
struct node
{
int x,y;
} s[Len][2];//记录路径
int n;
int a[Len];
int dp[Len][2];//dp[i][j],保存的就是看在j状态下,第i个位置所在的连续递增或递减序列,最前面那个比这个连续序列第一个大或者小的数
vector<int> ans[2];
void upp(int x,int y,int i,int j,int m)
{
if(dp[x][y]>m)
{
dp[x][y] = m;
s[x][y].x = i;
s[x][y].y = j;
}
}
void downn(int x,int y,int i,int j,int m)
{
if(dp[x][y]<m)
{
dp[x][y] = m;
s[x][y].x = i;
s[x][y].y = j;
}
}
//全递增
int set1()
{
int i,j;
dp[1][0] = dp[1][1] = -1;
up(i,1,n-1)
{
dp[i+1][0]=dp[i+1][1] = INF;
up(j,0,1)
{
if(dp[i][j]<INF)
{
if(a[i-1]<a[i])//数组的下标对应dp的下标+1,也就是说i对应i+1
upp(i+1,j,i,j,dp[i][j]);
if(dp[i][j]<a[i])//如果不满足j条件下的递增,那么就看能否保存到j^1的条件下
upp(i+1,j^1,i,j,a[i-1]);
}
}
}
return dp[n][0]+dp[n][1]<INF;
}
//全递减
int set2()
{
int i,j;
dp[1][0] = dp[1][1] = INF;
up(i,1,n-1)
{
dp[i+1][0]=dp[i+1][1] = -1;
up(j,0,1)
{
if(dp[i][j]>=0)
{
if(a[i-1]>a[i])
downn(i+1,j,i,j,dp[i][j]);
if(dp[i][j]>a[i])
downn(i+1,j^1,i,j,a[i-1]);
}
}
}
return dp[n][0]+dp[n][1]>0;
}
//一个递增,一个递减
int set3()
{
int i,j;
dp[1][0] = INF;
dp[1][1] = -1;
up(i,1,n-1)
{
dp[i+1][0] = -1;
dp[i+1][1] = INF;
if(dp[i][0]>0)
{
if(a[i-1]<a[i])
downn(i+1,0,i,0,dp[i][0]);
if(dp[i][0]>a[i])
upp(i+1,1,i,0,a[i-1]);
}
if(dp[i][1]<INF)
{
if(a[i-1]>a[i])
upp(i+1,1,i,1,dp[i][1]);
if(dp[i][1]<a[i])
downn(i+1,0,i,1,a[i-1]);
}
}
return dp[n][0]>0 || dp[n][1]<INF;
}
void solve(int x,int y)
{
if(x<=0) return;
ans[y].push_back(a[x-1]);
solve(s[x][y].x,s[x][y].y);
}
int main()
{
int i,j,k;
w(~scanf("%d",&n))
{
mem(s,0);
up(i,0,n-1)
scanf("%d",&a[i]);
if(set1() || set2() || set3())//三个只要有一个符合即可
{
ans[0].clear();
ans[1].clear();
if(dp[n][0]>=1 && dp[n][0]<=n) solve(n,0);
else solve(n,1);
up(i,0,1)
{
if(ans[i].empty())//如果整个序列是递增或者递减,那么取出一个来
{
ans[i].push_back(ans[i^1].back());
ans[i^1].pop_back();
}
}
printf("%d %d\n",ans[0].size(),ans[1].size());
up(i,0,1)
{
down(j,ans[i].size()-1,0)
{
printf("%d",ans[i][j]);
if(j)
printf(" ");
}
printf("\n");
}
}
else
puts("Fail");
}
return 0;
}
原文地址:http://blog.csdn.net/libin56842/article/details/45141875