29 lines
702 B
Text
29 lines
702 B
Text
$ cat << EOT > code/evaluate.py
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
from joblib import load
|
|
import json
|
|
from pathlib import Path
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from train import load_data
|
|
|
|
|
|
def main(repo_path):
    """Score the saved model on the test split and record its accuracy.

    Reads data/test.csv under repo_path through load_data, loads
    model.joblib from the same directory, predicts on the loaded data and
    compares the predictions with the loaded labels. The resulting
    {"accuracy": ...} dict is printed and written as JSON to accuracy.json
    under repo_path.

    repo_path must support the "/" operator and yield objects with a
    write_text method (the script entry point passes a pathlib.Path).
    """
    features, expected = load_data(repo_path / "data/test.csv")
    estimator = load(repo_path / "model.joblib")

    metrics = {
        "accuracy": accuracy_score(expected, estimator.predict(features)),
    }
    print(metrics)

    (repo_path / "accuracy.json").write_text(json.dumps(metrics))
|
|
|
|
|
|
if __name__ == "__main__":
    # Resolve to an absolute path before walking up two levels. Before
    # Python 3.9 __file__ can be relative (e.g. "evaluate.py" when launched
    # from inside code/), and the parent of "." is still ".", so the
    # unresolved form would point at code/ instead of the repository root.
    repo_path = Path(__file__).resolve().parent.parent
    main(repo_path)
|
|
EOT
|