'''
Created on May 09 2013
Edited on November 12 2015

@author: Mengxue CAO
'''

################################
# This program runs the training stage
################################
import random
import functions
import gsom_initial
import gsom_grow

# define file path
PATH = '/Users/mengxuecao/Codes/GSOM/bernd'

# training data
AUDIO_TRAINING_DATA = PATH+'/training_data/audio_coding_1'

# define variables
VECTOR_DIMENSION = 1368	# vector dimension of the audio training data

# initialize audio map
gmap_a = gsom_initial.initial(VECTOR_DIMENSION)

# read in the audio training data
training_data_a = gsom_grow.read_data(AUDIO_TRAINING_DATA)

print '\nNumber of tokens in the training data: ', len(training_data_a)

# create random sequence
random_order = range(len(training_data_a))


################################################################################################################################
# calibrating phase
# (1) calibrate A-Map
# (2) NS=1
# (3) no grow
# (4) record max error as the GT for growing phase
################################################################################################################################
print '\n######################### Caliberating Phase #########################'

# define number of caliberating phases
EPOCH_caliberating = 2

for i in range(EPOCH_caliberating):
	# files to be written for Audio map
	MATLAB_FILE_A = PATH + '/matlab/gmap_caliberting_a_' + str(i+1)
	GMAP_FILE_A = PATH + '/gmaps/gmap_caliberting_a_' + str(i+1)
	GMAP_RETRAIN_FILE_A = PATH + '/gmaps/gmap_retrain_caliberting_a_' + str(i+1)

	print '### epoch: ', i+1, '###'
	# initialzie error value for audio map
	error_a = 0

	# random sort items in the random sequence
	random.shuffle(random_order)

	for j in range(len(random_order)):
		# initialize Learning Rate: LR(t) for A-map at each input
		LR_a = 0.5 

		# prepare training data for A-map
		token_a = training_data_a[random_order[j]]		# token for A
		training_vector_a_char = []						# list container for training vectors of A in char
		training_vector_a_float = []					# list container for training vectors of A in float
		training_vector_a_char = token_a[0:-1]			# get training vectors of A
		training_lable_a = token_a[-1]					# get lable of A

		# convert the vector from string to float
		for k in range(len(training_vector_a_char)):		
			training_vector_a_float.append(float(training_vector_a_char[k]))
		
		#############################################################################################
		# Train the A map
		#############################################################################################
		print '\nTraining on A-MAP for token ...['+training_lable_a+']...'
	
		# find BMU and get the BMU representation (use winner-take-all in caliberating phase)
		bmu_a = gsom_grow.find_bmu_WTA(gmap_a, training_vector_a_float, error_a, 'euclidean')
		error_a = bmu_a[-1]
	
		# do weight update
		gmap_a = gsom_grow.weight_update_WTA(gmap_a, training_vector_a_float, bmu_a, LR_a)	# weight update
		
	print 'Max error a: ', error_a

	# write files
	functions.write_gmap(gmap_a, GMAP_FILE_A)					# write A gmap into file
	functions.write_gmap_matlab(gmap_a, MATLAB_FILE_A)			# write A gmap information into matlab file
	functions.write_gmap_retrain(gmap_a, GMAP_RETRAIN_FILE_A)	# write A retrain file

################################################################################################################################
# growing phase
# (1) NS=1 if network size < 10, otherwise NS=log(N)/log(10)
# (2) grow
# (3) use max error as the GT 
################################################################################################################################

print '\n######################### Growing Phase #########################'

# set GT value based on caliberating phase
GT_a = error_a

# define distance measurements: 'euclidean' OR 'cosine'
measure_a = 'euclidean'

# define number of growing phases
EPOCH_growing = 15

for n in range(EPOCH_growing):
	print '### epoch: ', n+1, '###'
	
	# get network size
	N_A = len(gmap_a)

	# files to be written
	MATLAB_FILE_A = PATH + '/matlab/gmap_training_a_' + str(n+1)
	GMAP_FILE_A = PATH + '/gmaps/gmap_training_a_' + str(n+1)
	GMAP_RETRAIN_FILE_A = PATH + '/gmaps/gmap_retrain_training_a_' + str(n+1)

	# random sort items in the random sequence
	random.shuffle(random_order)	

	# present a token to the A-map at a time, and update the A-map
	for i in range(len(random_order)):
		# initialize Learning Rate: LR(t) for A-map at each input
		LR_a = 0.5 

		# initialize Neighbourhood Size based on network size
		NS_a = functions.calculate_NS(len(gmap_a))	
		
		# prepare training data for A		
		token_a = training_data_a[random_order[i]]		# token in each step
		training_vector_a_char = []						# list container for training vectors of A in char
		training_vector_a_float = []					# list container for training vectors of A in float
		training_vector_a_char = token_a[0:-1]			# get training vectors of A
		training_lable_a = token_a[-1]					# get lable of A

		# convert the vector from string to float
		for j in range(len(training_vector_a_char)):		
			training_vector_a_float.append(float(training_vector_a_char[j]))

		
		#############################################################################################
		# Train the A map
		#############################################################################################
		print '\nTraining on A-MAP for token ...['+training_lable_a+']...'
		print 'NS is: ', NS_a
		print 'GT is: ', GT_a

		# find BMU and get the BMU representation
		bmu_a = gsom_grow.find_bmu(gmap_a, training_vector_a_float, GT_a, measure_a)

		# if bmu is boundary node and is growable, then grow
		if bmu_a[0]['is_grow'] == 1 and bmu_a[0]['is_boundary'] == 1:
			gmap_a = gsom_grow.grow_nodes(gmap_a, training_vector_a_float, bmu_a)
		
		# else, do weight update or error distribution
		else:
			while NS_a >= 1:
				gmap_a = gsom_grow.weight_update(gmap_a, training_vector_a_float, bmu_a, GT_a, LR_a, NS_a)	# weight update
				print 'Learning rate of A-MAP: ', LR_a
				NS_a = gsom_grow.calculate_neighbourhood_size(NS_a)		# reduce neighbourhood size
				LR_a = gsom_grow.calculate_learning_rate(gmap_a, LR_a)	# reduce learning rate

	
	#########################################################################################################
	# write files for trained A-map
	#########################################################################################################
	print '\nFinal number of nodes in A-MAP:', len(gmap_a)
	functions.write_gmap(gmap_a, GMAP_FILE_A)					# write A gmap into file
	functions.write_gmap_matlab(gmap_a, MATLAB_FILE_A)			# write A gmap information into matlab file
	functions.write_gmap_retrain(gmap_a, GMAP_RETRAIN_FILE_A)	# write retrain file for A-map
	