#include "disjset.h"

#include "except.h"
#include <vector>

using namespace std;

// Construct the disjoint sets object.
// numElements is the initial number of disjoint sets; the elements are the
// indices 0 .. numElements-1 and each one starts out as its own singleton set.
// A negative entry in s marks a set root, and -1 denotes the root of a tree of
// height 0, so every slot starts at -1.
// NOTE(review): a negative numElements is converted to an unsigned size by the
// vector constructor, which then typically throws std::length_error or
// std::bad_alloc rather than BadArgumentException -- confirm whether callers
// need that case handled explicitly.
DisjSet::DisjSet( long numElements ) : s( numElements, -1 )
{
    // The vector fill constructor already set every slot to -1, so no separate
    // initialisation loop is required.
}

// Union two disjoint sets using union by height.
// root1 is the root of set 1.
// root2 is the root of set 2.
// A root's entry stores -(height + 1), so a more negative value means a
// taller tree. The shorter tree is attached beneath the taller one; on a tie,
// root1 becomes the new root and its height grows by one.
// Passing the same root twice is a no-op, because the set is already joined
// with itself.
// Throw BadArgumentException if either argument is not a root.
void DisjSet::unionSets( long root1, long root2 )
{
    assertIsRoot( root1 );
    assertIsRoot( root2 );
    if( root1 == root2 )
    {
        // Without this guard the tie branch below would execute
        // s[ root1 ] = root1, turning the root into a self-loop that makes
        // find() recurse forever.
        return;
    }
    if( s[ root2 ] < s[ root1 ] )  // root2 is deeper
    {
        s[ root1 ] = root2;        // Make root2 new root
    }
    else
    {
        if( s[ root1 ] == s[ root2 ] )
        {
            s[ root1 ]--;          // Update height if same
        }
        s[ root2 ] = root1;        // Make root1 new root
    }
}

// Verify that root names an existing element that is currently a set root,
// i.e. its entry in s is negative.
// Throw a BadArgumentException if root is out of range or is not a set root.
void DisjSet::assertIsRoot( long root ) const
{
    assertIsItem( root );           // range check first so s[ root ] is safe
    const bool isRoot = s[ root ] < 0;
    if( !isRoot )
    {
        throw BadArgumentException( );
    }
}

// Verify that item is a valid element index, i.e. 0 <= item < s.size( ).
// Throw a BadArgumentException if item is not in range.
void DisjSet::assertIsItem( long item ) const
{
    // Once item < 0 has been ruled out, item can be widened to an unsigned
    // type and compared with s.size( ) directly. This avoids narrowing the
    // unsigned size down to long with a C-style cast.
    if( item < 0 || static_cast<unsigned long>( item ) >= s.size( ) )
    {
        throw BadArgumentException( );
    }
}

// Perform a find without modifying the structure (no path compression).
// Return the root of the set containing x.
// Throw BadArgumentException if x, or any parent index followed on the way
// up, is out of range.
long DisjSet::find( long x ) const
{
    assertIsItem( x );
    const long parent = s[ x ];
    if( parent < 0 )
    {
        return x;                   // negative entry: x is itself a root
    }
    return find( parent );          // keep climbing toward the root
}

// Perform a find with path compression.
// Return the root of the set containing x; every node visited on the way up
// is re-pointed directly at that root.
// Throw BadArgumentException if x is out of range.
long DisjSet::find( long x )
{
    assertIsItem( x );
    if( s[ x ] < 0 )
    {
        return x;                   // x is a root; nothing to compress
    }
    const long root = find( s[ x ] );
    s[ x ] = root;                  // link x straight to its root
    return root;
}

