The ideal interpolation kernel is the sinc function, whose spectrum contains all frequencies up to the Nyquist frequency and nothing above it. However, sinc has infinite support, so a discrete implementation would need an infinitely long convolution; we therefore truncate it with a window function. Thus, in signal processing the interpolation kernel is a windowed-sinc function [https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf], and the window function can be varied to achieve two goals: (1) a narrow window width, which reduces computation, and (2) at the same time, high accuracy with few artifacts. These two goals contradict each other, so we have to find an optimised compromise between them.
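Here is a minimal sketch of the windowed-sinc idea in plain Python. The Hann window is chosen purely for illustration. The helper names `sinc`, `hann` and `interpolate` and the parameter `half_width` are hypothetical. Each nearby sample is weighted by the windowed sinc evaluated at its offset, and the weighted samples are summed:

```python
import math
from typing import Sequence


def sinc(x: float) -> float:
    """Normalised sinc, sin(pi x) / (pi x), with sinc(0) = 1."""
    if x == 0.0:
        return 1.0
    return math.sin(math.pi * x) / (math.pi * x)


def hann(t: float) -> float:
    """Hann window on [-1, 1]: 1 at the centre, 0 at the edges and outside."""
    if abs(t) >= 1.0:
        return 0.0
    return 0.5 + 0.5 * math.cos(math.pi * t)


def interpolate(samples: Sequence[float], x: float, half_width: int = 4) -> float:
    """Windowed-sinc interpolation of unit-spaced samples at real position x.

    Only samples with |x - n| < half_width contribute, which is what makes
    the convolution finite; the window is stretched to [-half_width, half_width].
    """
    lo = max(0, math.floor(x) - half_width + 1)
    hi = min(len(samples) - 1, math.floor(x) + half_width)
    total = 0.0
    for n in range(lo, hi + 1):
        d = x - n
        total += samples[n] * sinc(d) * hann(d / half_width)
    return total
```

Sinc vanishes at nonzero integers and the window equals 1 at its centre, so evaluating at an integer position reproduces that sample up to rounding. A larger `half_width` costs more operations per output sample, which is exactly the trade-off between goals (1) and (2).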
The kernel must generally satisfy f(0) = 1 and f(n) = 0 for every nonzero integer n. Additionally, I force it to be a continuous function to avoid catastrophic artifacts. A generic window function can be written as a Fourier series (cosine terms only, if it is symmetric) [https://en.wikipedia.org/wiki/Window_function]
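Below, the series is written out with the coefficient names used in the `window_function` code further down: a_n for the cosine terms, b_n for the sine terms, and order N, on the interval [-1, 1]. This form is inferred from that code rather than taken from the lost original equation. For a symmetric window, all b_n = 0.

```latex
w(t) = \sum_{n=0}^{N} a_n \cos(n\pi t) + \sum_{n=1}^{N} b_n \sin(n\pi t), \qquad t \in [-1, 1]
```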
with absolute constraints, assuming the window is defined on [-1, 1] (to use it in practice, scale this in …)
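Evaluating the series at t = 0 and at t = ±1 turns the two requirements into linear conditions on the cosine coefficients, because every sine term vanishes at those points. The conditions are w(0) = Σ a_n = 1 and, for continuity at the window edges, w(±1) = Σ (−1)^n a_n = 0. This reading of the lost constraint equations is an assumption, but it is the one the projection code further down enforces. The following minimal check in plain Python uses the classical Hann and Blackman coefficients. The names `window` and `satisfies_constraints` are hypothetical:

```python
import math


def window(a, b, t):
    """Evaluate w(t) = sum a_n cos(n pi t) + sum b_n sin(n pi t)."""
    cos_part = sum(an * math.cos(n * math.pi * t) for n, an in enumerate(a))
    sin_part = sum(bn * math.sin(n * math.pi * t) for n, bn in enumerate(b))
    return cos_part + sin_part


def satisfies_constraints(a, tol=1e-12):
    """Check w(0) = 1 and w(+-1) = 0 directly on the cosine coefficients."""
    at_centre = sum(a)                                     # w(0)
    at_edge = sum((-1) ** n * an for n, an in enumerate(a))  # w(+-1), sines vanish
    return abs(at_centre - 1.0) < tol and abs(at_edge) < tol


hann_coeffs = [0.5, 0.5]               # 0.5 + 0.5 cos(pi t)
blackman_coeffs = [0.42, 0.5, 0.08]    # 0.42 + 0.5 cos(pi t) + 0.08 cos(2 pi t)
```

Both classical windows pass, while a rectangular window (a = [1]) fails the edge condition, which is exactly the discontinuity the constraint is meant to rule out.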
Projection of the window parameters onto the constraints
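From these constraints, valid window parameters lie on the intersection of two hyperplanes in coefficient space, so any parameter vector can be corrected by projecting it onto that intersection.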
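Assume the two constraints read Σ a_n = 1 (w(0) = 1) and Σ (−1)^n a_n = 0 (w(±1) = 0), as the projection code further down implies. Adding and subtracting them splits the conditions by index parity. The symbols below are hypothetical: **a** is the coefficient vector, and **e** and **o** are indicator vectors with ones at even and odd indices, mirroring `even_kernel` and `odd_kernel` in the code. Each constraint becomes a hyperplane with its own normal, and the two normals are perpendicular:

```latex
\begin{gathered}
\mathbf{a} = (a_0, \dots, a_N), \quad
\mathbf{e} = (1, 0, 1, 0, \dots), \quad
\mathbf{o} = (0, 1, 0, 1, \dots), \quad
\mathbf{e} \cdot \mathbf{o} = 0, \\
\mathbf{e} \cdot \mathbf{a} = \tfrac{1}{2}, \qquad
\mathbf{o} \cdot \mathbf{a} = \tfrac{1}{2}.
\end{gathered}
```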
Let us make the projection idea more concrete.
The two perpendicular vectors are the coefficient vectors of the plane equations (that is, the planes' normals), so consider the 2D plane spanned by these normals that passes through a given point.
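Write the given point as **a**, the coefficient vector (a_0, ..., a_N). Write the normals as **e** and **o**, hypothetical names for the even- and odd-index indicator vectors, like `even_kernel` and `odd_kernel` in the code. With hypothetical coordinates λ and μ, every point of that plane is:

```latex
\mathbf{a}' = \mathbf{a} + \lambda \mathbf{e} + \mu \mathbf{o}, \qquad \lambda, \mu \in \mathbb{R}
```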
To project, we intersect this plane with the two constraint hyperplanes.
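Assume the following notation. **a** is the point and **e**, **o** are the even/odd index indicator vectors, with **e**·**o** = 0. λ and μ are the coordinates of the plane **a** + λ**e** + μ**o**. Finally, n_e = **e**·**e** and n_o = **o**·**o** count the even and odd indices. Substituting the plane into both hyperplane equations, e·a = 1/2 and o·a = 1/2, gives two decoupled scalar equations:

```latex
\mathbf{e} \cdot (\mathbf{a} + \lambda \mathbf{e} + \mu \mathbf{o}) = \mathbf{e} \cdot \mathbf{a} + \lambda n_e = \tfrac{1}{2},
\qquad
\mathbf{o} \cdot (\mathbf{a} + \lambda \mathbf{e} + \mu \mathbf{o}) = \mathbf{o} \cdot \mathbf{a} + \mu n_o = \tfrac{1}{2}
```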
From these equations, we solve for the displacement along each normal.
Simplifying further,
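Assume λ and μ are the displacements along the even- and odd-index indicator vectors **e** and **o**, and n_e and n_o count the even and odd indices among 0..N. Solving each equation separately then gives the values below. The explicit counts reproduce the per-coefficient factors `proj_coeffs` in the code, which is why I believe this is the intended derivation:

```latex
\lambda = \frac{1 - 2\,\mathbf{e} \cdot \mathbf{a}}{2 n_e}, \qquad
\mu = \frac{1 - 2\,\mathbf{o} \cdot \mathbf{a}}{2 n_o}, \qquad
(2 n_e,\; 2 n_o) =
\begin{cases}
(N + 1,\; N + 1) & N \text{ odd} \\
(N + 2,\; N) & N \text{ even}
\end{cases}
```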
Thus, the projected point is the original point shifted by these amounts along the two normals.
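The notation assumed here is as follows: **a** is the coefficient vector, **e** and **o** are the even- and odd-index indicator vectors, and n_e and n_o are their numbers of ones. In that notation, the projection reads as below. Component-wise, this is what `window_param_projection.forward` computes with `vector = 1 - 2 * sum_mosaic_a` scaled by `proj_coeffs`.

```latex
P(\mathbf{a}) = \mathbf{a}
+ \frac{1 - 2\,\mathbf{e} \cdot \mathbf{a}}{2 n_e}\,\mathbf{e}
+ \frac{1 - 2\,\mathbf{o} \cdot \mathbf{a}}{2 n_o}\,\mathbf{o}
```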
Thus, we may define a projection operator on the window parameters.
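Here is a minimal plain-Python sketch of that operator, independent of PyTorch. The function name `project_onto_constraints` is hypothetical. The operator shifts every even-indexed coefficient by (1/2 − sum of even coefficients) / n_e, and every odd-indexed one analogously. This equals (1 − 2 e·a) / (2 n_e) and (1 − 2 o·a) / (2 n_o). At least two coefficients (N ≥ 1) are assumed.

```python
def project_onto_constraints(a):
    """Orthogonally project cosine coefficients a_0..a_N onto
    sum(even-indexed a) = 1/2 and sum(odd-indexed a) = 1/2,
    i.e. onto w(0) = 1 and w(+-1) = 0. Requires len(a) >= 2."""
    even = range(0, len(a), 2)
    odd = range(1, len(a), 2)
    shift_e = (0.5 - sum(a[i] for i in even)) / len(even)  # lambda
    shift_o = (0.5 - sum(a[i] for i in odd)) / len(odd)    # mu
    return [x + (shift_e if i % 2 == 0 else shift_o) for i, x in enumerate(a)]
```

For example, an arbitrary vector such as [1.0, 0.3, -0.2, 0.4] is mapped to [0.85, 0.2, -0.35, 0.3], which satisfies both constraints. An already valid vector such as the Hann coefficients [0.5, 0.5] is returned unchanged, as an orthogonal projection should.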
Gradient vector projection onto the constraints
Projecting the parameters after every training iteration is inefficient, so instead we want to project only each update step, so that the parameters stay on the constraints.
Starting from a valid point, a step keeps the parameters valid if and only if it leaves both constraint sums unchanged.
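Let **e** and **o** be the even- and odd-index indicator vectors (hypothetical names matching `even_kernel` and `odd_kernel` in the code). If **a** already satisfies **e**·**a** = **o**·**a** = 1/2, then **a** + **d** is valid exactly when the step **d** satisfies homogeneous conditions. These have no constant term:

```latex
\mathbf{e} \cdot (\mathbf{a} + \mathbf{d}) = \tfrac{1}{2} + \mathbf{e} \cdot \mathbf{d}, \quad
\mathbf{o} \cdot (\mathbf{a} + \mathbf{d}) = \tfrac{1}{2} + \mathbf{o} \cdot \mathbf{d}
\quad\Longrightarrow\quad
\mathbf{e} \cdot \mathbf{d} = \mathbf{o} \cdot \mathbf{d} = 0
```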
Thus, we define a new projection operator for update steps (gradients).
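Below is a minimal plain-Python sketch of the step projection. The operator subtracts the mean of the even entries from each even-indexed entry of the step, and likewise for the odd entries. That equals d − (e·d / n_e) e − (o·d / n_o) o. It is the same operation as calling `window_param_projection` with `is_differential_vector=True`. The quadratic objective, the function names and the step size are hypothetical. They only demonstrate that a valid starting point stays valid without the parameters ever being re-projected:

```python
def project_step(d):
    """Remove the components of a step d along e and o, so that
    e.d = o.d = 0 and a valid point stays valid after a + d."""
    even = range(0, len(d), 2)
    odd = range(1, len(d), 2)
    mean_e = sum(d[i] for i in even) / len(even)  # (e.d) / n_e
    mean_o = sum(d[i] for i in odd) / len(odd)    # (o.d) / n_o
    return [x - (mean_e if i % 2 == 0 else mean_o) for i, x in enumerate(d)]


def projected_gradient_descent(a, target, lr=0.25, steps=60):
    """Hypothetical toy objective: minimise sum((a - target)^2) while keeping
    w(0) = 1 and w(+-1) = 0. Only the gradient is projected; a never is."""
    for _ in range(steps):
        grad = [2.0 * (x - y) for x, y in zip(a, target)]
        step = project_step(grad)
        a = [x - lr * s for x, s in zip(a, step)]
    return a


def constraint_residuals(a):
    """Return (w(0) - 1, w(+-1)) for cosine coefficients a."""
    return sum(a) - 1.0, sum((-1) ** n * x for n, x in enumerate(a))
```

The descent starts from the zero-padded Hann coefficients [0.5, 0.5, 0, 0] and is pulled towards the target [1, 0, 0, 0]. The iterates converge to [0.75, 0.25, -0.25, 0.25], the valid point closest to the target. Both constraint residuals stay at rounding level throughout.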
Pseudocode (PyTorch)
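import math

import torch
import torch.nn as nn


class window_param_projection(nn.Module):
    # Projects cosine coefficients a = (a_0, ..., a_N) onto the constraints
    # sum of even-indexed a = 1/2 and sum of odd-indexed a = 1/2,
    # i.e. w(0) = sum(a) = 1 and w(+-1) = sum((-1)^n a_n) = 0. Requires N >= 1.
    def __init__(self, N: int):
        super().__init__()
        self.N = N
        odd_kernel = torch.zeros(N + 1, dtype=torch.float32)
        even_kernel = torch.zeros(N + 1, dtype=torch.float32)
        odd_kernel[1::2] = 1
        even_kernel[0::2] = 1
        n = N + 1
        # Per-coefficient factor: 1 / (2 * number of even indices) on even positions,
        # 1 / (2 * number of odd indices) on odd positions.
        if N % 2 != 0:  # odd N: both counts equal (N + 1) / 2
            proj_coeffs = torch.full((n,), 1.0 / n)
        else:  # even N: N/2 + 1 even indices, N/2 odd indices
            proj_coeffs = (n - even_kernel + odd_kernel) / (n * n - 1)
        # Buffers follow .to(device) but are not trained.
        self.register_buffer("odd_kernel", odd_kernel)
        self.register_buffer("even_kernel", even_kernel)
        self.register_buffer("proj_coeffs", proj_coeffs)

    def forward(self, a: torch.Tensor, is_differential_vector: bool = False) -> torch.Tensor:
        # Call under torch.no_grad() when projecting parameters in place of training.
        sum_odd_a = a.matmul(self.odd_kernel)    # o . a, shape a.shape[:-1]
        sum_even_a = a.matmul(self.even_kernel)  # e . a
        # Broadcast e . a to even positions and o . a to odd positions.
        sum_mosaic_a = torch.zeros_like(a)
        sum_mosaic_a[..., 0::2] = sum_even_a.unsqueeze(-1)
        sum_mosaic_a[..., 1::2] = sum_odd_a.unsqueeze(-1)
        vector = -2 * sum_mosaic_a
        if not is_differential_vector:
            # Parameters move onto e . a = o . a = 1/2; steps/gradients only
            # lose their components along e and o (no constant term).
            vector = vector + 1
        delta_a = vector * self.proj_coeffs  # element-wise scaling, not matmul
        return a + delta_a


class window_function(nn.Module):
    # w(t) = sum_n a_n cos(n pi t) + sum_n b_n sin(n pi t) for t in [-1, 1].
    def __init__(self, N: int, a: torch.Tensor, b: torch.Tensor):
        super().__init__()
        self.register_buffer("counting_array", torch.linspace(0, N, N + 1))
        self.a = nn.Parameter(a.detach().clone())
        b = b.detach().clone()
        b[..., 0] = 0  # sin(0) = 0, so b_0 has no effect
        self.b = nn.Parameter(b)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Outer product of t with n = 0..N, giving shape t.shape + (N + 1,).
        args = t.unsqueeze(-1) * self.counting_array * math.pi
        return torch.cos(args).matmul(self.a) + torch.sin(args).matmul(self.b)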