14 lines
474 B
Text
14 lines
474 B
Text
$ git diff
diff --git a/code/train.py b/code/train.py
index 3b309e1..017a6bf 100644
--- a/code/train.py
+++ b/code/train.py
@@ -39,7 +39,7 @@ def load_data(data_path):
 def main(repo_path):
     train_csv_path = repo_path / "data/train.csv"
     train_data, labels = load_data(train_csv_path)
-    sgd = SGDClassifier(max_iter=10)
+    sgd = SGDClassifier(max_iter=100)
     trained_model = sgd.fit(train_data, labels)
     dump(trained_model, repo_path / "model.joblib")
 