mojo's Blog

The Overfitting Problem 본문

머신러닝

The Overfitting Problem

_mojo_ 2023. 4. 3. 16:58

The Overfitting Problem

 

 

※ Polynomial Curver Fitting

 

Which order polynomial does best fit for the data?

\(f(x) = w_{0} + w_{1}x + ... + w_{M}x^{M} = \sum_{i=0}^{M} w_{i}^{T}x^{i}\)

 

다항식은 M차원으로 표현되며 주어진 data를 얼마나 잘 표현하는지, 주어진 data에 적합한지를 찾는 문제가

있다고 할 때, 과연 몇 차원의 다항식을 활용하는 것이 주어진 data를 가장 잘 표현하는 것인지를 해결해야 한다.

또한 차원이 주어진다면, w 값을 또한 어떻게 선택해낼 것인지도 문제이다.

 

We want to minimize the sum-of-squared error function.

\(E(w) = \sum_{i=1}^{n} (f(x^{(i)}) - y^{(i)})^{2}\)

 

model 이 error 를 최소화하는지를 확인하는 error function 은 위와 같다.

 

 

※ Example: Model Comparisons

 

Which model is better?

 

 

위 사진을 보면 M = 1 인 경우가 M = 0 인 경우보다 error 값이 더 작은 것을 확인할 수 있다.

 

 

그렇다면 M = 3 인 경우와 M = 9 인 경우에 대해 어떤 모델이 더 좋을지 생각해본다면,

error function 관점에서는 확실히 M = 9 인 경우가 좋다.

하지만 이 문제에 대해 어떤 모양을 띠고 있는지 알고 있기 때문에, 정답 분포 관점에서는

M = 3 인 경우가 더 좋다고 볼 수 있다.

 

As the order (M) increases,

  - The complexity of model increases.

 

As the complexity of model increases,

  - The model can more exactly learn the given data.

  - The prediction accuaracy does not necessarily increase.

 

다항식 차수가 커질수록 model 의 복잡도가 커진다.

따라서 complexity 가 증가할수록 데이터를 표현하는데 더 정확하게 학습할 수 있다.

그러나 너무 복잡도가 커진다면 error 값이 0에 가까워지겠지만, 제일 잘 맞춘다고 말할 수는 없다.

 

 

※ Overfitting Problem

 

For M = 9, the training error is zero.

  - The polynomial contains 10 degrees of freedom corresponding to 10 parameters,

     so we can be fixed exactly to the 10 data points.

 

However, the test error has become very large. Why?

 

가지고 있지 않은 test data에 대해서 고차원 모델을 적용하게 될 경우 예상치 못한 error 값이

커질 수 있게 된다.

이러한 현상을 Overfitting Problem 이라고 부른다.

 

As M increases, the magnitude of coefficients gets larger.

  - For M = 9, the coefficients have become finely tuned to the training data.

  - Between data points, the function exhibits large oscillations.

 

M = 9 인 경우에 대해 지나치게 복잡한 모델을 학습한 경우라고 볼 수 있다.

즉, test data 에 대해 적합하게 학습이 이루어졌다라고 얘기할 수 없다.

 

 

※ Overfitting vs. Generalization

 

What is the purpose of machine learning?

  (1) Learning the given data as exactly as possible

  (2) Predict the unknown data as exactly as possible based on the given data

 

머신러닝 근본적인 목적은 주어진 data가 아닌, unkown data 를 잘 예측하는 것이다.

하지만 현실적으로 보이지 않은 data를 잘 맞춘다는 것은 어렵다.

따라서 training data를 가지고 모델을 잘 학습한다면 아마도 test data에 대해서 잘 맞출것이라는 

생각을 할 수 있다.

 

How to Generalize our Models

 

※ What is Genearlization?

 

Expect the model to generalize if it explains the data well given the complexity of the model.

 

 

training error 을 줄이는 것이 아닌, unknown data 에 대한 error 를 줄이는 것으로

최적의 상태에 도달하도록 학습하는 것이 모델의 generalization 이라고 볼 수 있다.

 

 

※ How to Achieve Generalization?

 

The goal is to achieve good generalization by making accurate predictions for test data.

  - Choosing the values of parameters that minimize the error function on the training data

     may not be the best option.

 

We would like to model the true regularities in the data and ignore the noise in the data.

  - It is difficult to know which regularities are real and which are accidental among training samples.

 

단순히 error function 을 최소화하도록 파라미터를 선택하는 것이 최선의 방향이라고 볼 수 없다.

즉, generalization error 측면에서는 최고의 선택이라고 보기 어렵다.

그리고 model이 training data가 아닌 test data 에서 잘 예측하기 위해 일종의 어떤 제약조건을

두는 것이 필요할 수 있다.

 

 

※ Increasing the Size of Data

 

For a given model complexity, the overfitting problem becomes less severe

as the data size increases.

The number of parameters is not necessarily the most appropriate measure

of the model complexity.

 

첫 번째 해결책은 주어진 data 수를 늘리는 방법이 있다.

모델의 complexity 를 고정한 상태에서 data 사이즈를 늘린다면 자연스럽게 복잡하게 표현된

data는 수 많은 data, 추가적으로 표현된 data에 의해 overfitting error 가 완화된다고 볼 수 있다.

 

 

※ Penalizing the Model Complexity

 

\(\hat{\theta} = \underset{\theta}{argmin} \frac{1}{n} \sum_{i=1}^{n} \zeta (f(x^{(i)}), y^{(i)}) + \lambda \Omega (\theta)\)
\(\zeta (f(x^{(i)}), y^{(i)})\) : Fit the data
\(\lambda \Omega (\theta)\) : Penalize complex models (Regularization parameter)

 

두 번째 해결책은 model의 복잡도가 높아지면 그것에 대한 penalty를 주는 term을 추가하는 것이다.

그러면 패널티를 부여하는 오메가 함수에 대해 먼저 알아보도록 한다.

 

 

※ Common Regularization Functions

 

Lasso regression (L1-Reg)

\(\Omega_{Lasso} (\theta) = \sum_{i=1}^{d}|\theta_{i}|\)

 

- Encourage sparsity by setting weight = 0.

  => Used to select the most informative features.

- Does not have an analytic solution

   => numerical methods.

 

첫 번째 방식은 Lasso regression 이다.

절댓값의 합으로 표현되며 불필요한 부분에 가중치 0을 부여함으로써 훨씬 더 sparse 한 형태로

주어진 선형 model 을 표현하는 특징을 갖는다.

그리고 Lasso 함수는 미분하여 0이 되는 지점을 찾기 어렵기 때문에, numerical method 를

통해 해결해야 한다.

 

Ridge regression (L2-Reg)

\(\Omega_{Ridge} (\theta) = \sum_{i=1}^{d}\theta_{i}^{2}\)

 

- Does not encourage sparsity

  => small but non-zero weights.

- Distributes weight across related features (robust).

- Analytic solution (easy to compute)

 

두 번째 방식은 Ridge regression 이다.

제곱의 합으로 표현되며 미분하여 0이 되는 지점을 쉽게 계산할 수 있어서 analytic solution 을 활용한다.

그리고 가중치 값이 작지만 0의 값을 갖지 않는다는 특징을 가진다. 

 

 

※ Adopting the Weight Decay

 

One technique for controlling overfitting problem is regularization,

which amounts to adding a penalty term to the error function.

  - Shrinking to zero : penalize coefficients based on their size.

  - For a penalty function which is the sum of the squares of the parameters,

     this is known as "weight decay", or "ridge regression".

\( E(w) = \sum_{i=1}^{n} (f(x^{(i)}) - y^{(i)})^{2} \)
\( \downarrow \)
\(E(w) = \sum_{i=1}^{n} (f(x^{(i)}) - y^{(i)})^{2} + \lambda |w|^{2}\)

 

위와 같이 L2 ridge regression 방식을 사용하는 경우가 대표적이고,

\(\lambda |w|^{2}\) 을 붙이는 regularization 방식을 Weight Decay 라고 부른다.

 

 

※ Example: Adopting the Weight Decay

 

Training and test errors vs. regularization for the M = 9 polynomial

Small \(lambda\) vs. Large \(lambda\)

 

 

오른쪽 사진을 보면 \(\lambda\) 값이 되게 작은 경우에 대해서 overfitting 문제가 발생한다.

그리고 \(\lambda \) 값이 점차 커지면 test data와 training data와의 에러 차이가 점차 작아지는

것을 확인할 수 있다.

즉, \(ln\lambda\) 값이 -18 일 때 왼쪽 사진과 같이 최적의 모양꼴로 표현되는 것을 알 수 있다.

\(\lambda\) 값이 더 커져서 test, training 데이터에 해당하는 error 값이 점차적으로 증가하게 되는데

model 이 지나치게 단순화되고 있다라고 볼 수 있고 underfitting 문제라고 볼 수 있다.

 

'머신러닝' 카테고리의 다른 글

Cross-Validation  (0) 2023.04.04
Multinomial Logistic Regression  (0) 2023.03.29
The Concept of Logistic Regression  (0) 2023.03.29
Parameter Estimation  (0) 2023.03.21
Classification Problem  (0) 2023.03.19
Comments