#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <climits>
#include <cstring>
#include <string>
#include <set>
#include <bitset>
#include <map>
#include <queue>
#include <stack>
#include <vector>
#include <cassert>
#include <ctime>
// Contest-style shorthand. In this file only rep, vi, pb and ll are actually
// used; the remaining macros/constants are unused boilerplate.
// rep(i,m,n): for i = m..n inclusive; `i` must already be declared. Because n
// is not parenthesized, the (int) cast binds only to n's leading operand, e.g.
// rep(i,0,v.size()-1) tests i <= (int)v.size()-1, i.e. -1 for an empty vector.
#define rep(i,m,n) for(i=m;i<=(int)n;i++)
// Large sentinel value (about 1.06e9); every byte of it is 0x3f.
#define inf 0x3f3f3f3f
// The prime 1e9+7.
#define mod 1000000007
#define vi vector<int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define ll long long
#define pi acos(-1.0)
#define pii pair<int,int>
// Windows console pause.
#define sys system("pause")
// Child indices 2*rt and 2*rt+1 of node rt in an array-stored binary tree.
#define ls (rt<<1)
#define rs (rt<<1|1)
#define all(x) x.begin(),x.end()
// Bound for per-vertex arrays (presumably n <= 1e5 — confirm with the input limits).
const int maxn=1e5+10;
// Unused in this file.
const int N=5e4+10;
using namespace std;
// Greatest common divisor via the Euclidean algorithm, written as a loop
// rather than tail recursion. gcd(p, 0) returns p unchanged.
ll gcd(ll p,ll q){
    while (q != 0) {
        ll rem = p % q;
        p = q;
        q = rem;
    }
    return p;
}
// Computes (p * q) mod mo by binary doubling ("Russian peasant"
// multiplication), so the full product is never formed.
// Preconditions: mo > 0; q >= 0 (the loop shifts q right until it is 0);
// 2 * (mo - 1) must fit in long long so the additions below cannot overflow.
// p is reduced modulo mo up front. Without that, an unreduced p close to the
// top of the long long range would overflow in p + p.
long long qmul(long long p,long long q,long long mo){
    p %= mo;
    long long f = 0;
    while (q) {
        if (q & 1) f = (f + p) % mo;
        p = (p + p) % mo;
        q >>= 1;
    }
    return f;
}
// Integer power p^q by binary exponentiation, with no modular reduction.
// Both the result and the running square (which is also formed after the
// last bit) can overflow long long, which is undefined behaviour for signed
// types. q is expected to be non-negative: with an arithmetic right shift a
// negative q never reaches 0.
ll qpow(ll p,ll q){
    ll result = 1;
    ll base = p;
    for (ll power = q; power != 0; power >>= 1) {
        if (power & 1)
            result *= base;
        base *= base;
    }
    return result;
}
// n: number of vertices; m: read by main() but not otherwise used.
int n, m;
// k, t: not used in this file.
int k, t;
// A state of the generalized suffix automaton.
struct samnode{
    samnode *son[10]; // transitions indexed by symbol 0..9 (vertex colors); null when absent
    samnode *f;       // suffix link; null for the root
    int l;            // length of the longest string belonging to this state
};
// root: initial state. last: set only by init() in this file.
// sam: static pool of states. It is zero-initialized, so a freshly allocated slot
// starts with all-null transitions. Capacity maxn*40 is presumably 2 states per
// add() times up to 20 leaf traversals of n vertices — TODO confirm against the input bounds.
samnode *root, *last, sam[maxn*40];
// Index of the most recently allocated pool slot; states 1..cnt are in use.
int cnt;
// Make sam[0] the root (and the initial `last`) and restart the pool counter.
// The contents of previously used slots are not cleared. A fresh automaton
// relies on the pool's static zero-initialization, so this is meant to be called once.
void init(){
    cnt = 0;
    root = &sam[0];
    last = root;
}
// Extend the generalized suffix automaton by symbol x, starting from state
// `cur` (the state of the string being extended).
//
// A new state p is always taken from the pool. Afterwards cur->son[x] is the
// state for (cur's string + x). If cur already had an x-transition, p ends up
// with no incoming transitions and p->l == p->f->l, so it adds 0 to the sum of
// (l - f->l) taken in main(). Callers should continue from cur->son[x], not p.
//
// The parameter is deliberately not named `last`, so it does not shadow the
// global of that name. This function never updates the global `last`.
void add(int x,samnode *cur)
{
    samnode *p = &sam[++cnt];
    samnode *jp = cur;
    p->l = jp->l + 1;
    // Walk up the suffix-link chain, pointing x-transitions at p until a
    // state that already has an x-transition is found.
    for ( ; jp && !jp->son[x] ; jp = jp->f) jp->son[x] = p;
    if (!jp) {
        // No suffix of cur's string had an x-transition: link p to the root.
        p->f = root;
    } else {
        samnode *q = jp->son[x];
        if (jp->l + 1 == q->l) {
            // q's longest string is exactly jp's longest string + x: link to it.
            p->f = q;
        } else {
            // q also holds longer strings: split off a clone r of length jp->l+1,
            // make r the suffix link of both q and p, and redirect the
            // x-transitions that pointed at q along jp's suffix chain to r.
            samnode *r = &sam[++cnt];
            *r = *q;
            r->l = jp->l + 1;
            q->f = p->f = r;
            for ( ; jp && jp->son[x] == q ; jp = jp->f) jp->son[x] = r;
        }
    }
}
// Tree storage, indexed by vertex id (1..n):
//   e[v]  - adjacency list of v
//   du[v] - degree of v (main() uses it to pick traversal start vertices)
//   c[v]  - color of v, used as the automaton symbol
vi e[maxn];
int du[maxn];
int c[maxn];
void dfs(int x,int y,samnode *rt)
{
int i;
add(c[x],rt);
rt=rt->son[c[x]];
rep(i,0,e[x].size()-1)
{
int z=e[x][i];
if(z==y)continue;
dfs(z,x,rt);
}
}
// Reads a tree with colored vertices, inserts every path that starts at a
// leaf into one generalized suffix automaton, and prints the number of
// distinct non-empty color strings it recognizes. That number is the sum of
// (l - f->l) over all non-root states.
//
// Input: n and m, then the colors c[1..n], then n-1 edges "x y".
// Colors are used directly as transition indices, so they must lie in [0, 10).
// m is read but not otherwise used.
int main(){
    int i;
    init();
    scanf("%d%d",&n,&m);
    rep(i,1,n)scanf("%d",&c[i]);
    rep(i,1,n-1)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        du[x]++,du[y]++;
        e[x].pb(y),e[y].pb(x);
    }
    // Start a traversal from every leaf: each simple path in a tree lies on a
    // leaf-to-leaf path, and reading it from either end covers both directions.
    // A single-vertex tree has degree 0, so "leaf" must mean degree <= 1;
    // otherwise n == 1 would print 0 instead of 1.
    rep(i,1,n)if(du[i]<=1)dfs(i,0,root);
    ll ret=0;
    rep(i,1,cnt)ret+=sam[i].l-sam[i].f->l;
    printf("%lld\n",ret);
    return 0;
}