##############################################################################
## 
#  \file calibrate.py
#  \brief PZT calibration.
#  \version 0.3
#  \author Dimitri Denk
#  \date 12.08.2011
#
#  This file implements the PZT calibration function for DTI interferometers.
#
#  Copyright (c) 2010-2021
#  DENKTECH
#  http://www.denktech.de
#
#  Permission to use, copy, modify, distribute and sell this software
#  and its documentation for any purpose is hereby granted without fee,
#  provided that the above copyright notice appear in all copies and
#  that both that copyright notice and this permission notice appear
#  in supporting documentation. DENKTECH makes no representations about 
#  the suitability of this software for any purpose. 
#  It is provided "as is" without express or implied warranty.
#
#  History
#  12.08.2011 DD first release
#  30.07.2021 DD reworking of code for Python3
#
#############################################################################

##
# \addtogroup DTI_SCRIPTS scripts
# \ingroup DTI_TOOLS
#
# Python scripts for DTI interferometers
#
# @{

import time
from scipy.optimize import brentq
import numpy as np


## Calibrate PZT v2
#
#  This function implements PZT calibration.
#  FFT is used as phase detection algorithm.
#  The calibration refers to vertical interference fringes of flat surface.
#
#  \param port - dtcam object
#  \param umax - normalized maximal PZT voltage, float in range from 0. to 1.
#  \param p - number of acquired points, positive integer
#  \param n - number of measurements, positive integer
#  \param verb - verbose output, boolean
#  \param sve - save intermediate results, boolean
#  \retval res - result of operation, True if success
#  \retval list - normalized calibrated PZT voltages, list of floats in range from 0. to 1.
def calibrate(port: object, umax: float,
	points: int = 10, samples: int = 2, repeats: int = 1,
	verb: bool = True, sve: bool = False) -> np.ndarray:
	csize = .5
	start = time.time()

	if verb:
		print("PZT calibration")
		print("PZT max value", umax)
		print("number of points", points)
		print("number of samples", samples)
		print("number of repeats", repeats)

	ref = []
	for v in np.linspace(0., umax, points):
		for i in range(samples):
			ref.append(v)
	# back movement
	ref += ref[::-1]
	# repeats
	ref *= repeats
	if verb:
		print("Reference points", ref)

	# measurement
	if verb:
		print("Measurement...")
	frames = []
	num = len(ref)
	for shift in ref:
		# capture frame
		f = port.get_frame(shift, verb)
		# cut middle part of image
		xb = int(f.shape[0] * (1. - csize) / 2)
		xe = xb + int(f.shape[0] * csize)
		yb = int(f.shape[1] * (1. - csize) / 2)
		ye = yb + int(f.shape[1] * csize)
		frames.append(f[xb:xe, yb:ye])

	# analysis
	if verb:
		print("Data processing...")
	p = np.linspace(0, 1., num) * 2 * np.pi
	F = []
	for frame in frames:
		F.append(frame.reshape(frame.size))
	
	for j in range(10):
		m = []
		for _ref_ in p:
			m.append([1, np.cos(_ref_), -np.sin(_ref_)])
		m = np.array(m)
		A = np.linalg.pinv(m)
		
		F = np.array(F)
		K = np.dot(A, F)
		
		f = np.arctan2(K[2], K[1])
		
		m = np.array([np.ones(f.size), np.cos(f), -np.sin(f)])
		A = np.linalg.pinv(m).transpose()
		for i in range(num):
			K = np.dot(A, F[i])
			p[i] = np.arctan2(K[2], K[1])

	# unwrap phase
	fp = np.mean(p.reshape(-1, samples), axis=1)
	rp = np.mean(np.array(ref).reshape(-1, samples), axis=1)
	# phase offset correction
	fp -= fp[0]

	for i in range(1, len(fp)):
		while (fp[i] - fp[i - 1]) > np.pi:
			fp[i:] = fp[i:] - 2 * np.pi
		while (fp[i] - fp[i - 1]) < -np.pi:
			fp[i:] = fp[i:] + 2 * np.pi

	if fp[1] < fp[0]:
		fp = - fp

	if verb:
		print("Phase response:")
		for i in range(len(rp)):
			print("shift %.3f, phase %.3f" % (rp[i], fp[i]))

	# calculate polynomial fit
	z = np.polyfit(rp, fp, 5)
	y = np.poly1d(z)

	# find roots
	if verb:
		print("Calculating of shift values...")
	root = []
	for y0 in [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi]:
		if verb:
			print("target phase", y0)
		val = brentq(lambda x: y(x) - y0, -0.1 * umax, umax)
		print("target phase %.3f, calculated shift %.3f" % (y0, val, ))
		root.append(val)
		
	# add offset to avoid 0V at PZT
	root = (np.array(root) - root[0]) + 0.01

	if sve:
		np.save('pzt_ref.npy', root)
	
	time1 = time.time()
	if verb:
		print("Time " + str(int((time1 - start) * 1000 + .5)) + "ms")
		print('Calibrated values: ' + str(root))

	# normalized calibrated PZT voltages
	return root

##
# @}

# EOF
