GitLab wurde erfolgreich aktualisiert. Durch regelmäßige Updates bleibt das THM GitLab sicher. Danke für Ihre Geduld.

Commit 4694cb4c authored by Jens Plüddemann's avatar Jens Plüddemann

descision tree

parent bee5e4a1
......@@ -15,6 +15,8 @@ tensorflow = "*"
statsmodels = "*"
xlrd = "*"
missingno = "*"
pydotplus = "*"
graphviz = "*"
[requires]
python_version = "3.7"
This diff is collapsed.
from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus
import matplotlib.pyplot as plt
class DescisionTree:
def __init__(self):
# load_iris() gibt ein Dictionary mit mehreren Variablen zurück
iris = load_iris()
# Um auf die Daten zuzugreifen nehmen wir den Key "data"
data = iris["data"]
# Von den Daten interessieren uns nicht alle Werte
# Daten liegen in der Form [[..., ..., ..., ...], [..., ..., ..., ...], ...]
# Also einer Liste mit vielen Listen an jeweils 4 Werten vor
# Mit dem Aufruf [:, 2:] holen wir uns aus jeder einzelnen Liste, der vielen Listen, die letzten 2 Werte
X = data[:, 2:]
# Um auf das Target zuzugreifen nehmen wir den Key "target"
y = iris["target"]
# Erstellen des Baums und Hinzufügen der Daten
tree_iris = tree.DecisionTreeClassifier(criterion='entropy', random_state=0, max_depth=2).fit(X, y)
# Plotten des Baumes [2:] ignoriert dabei die ersten beiden Werte (sepal length, ...)
tree.plot_tree(tree_iris, feature_names=iris["feature_names"][2:], class_names=iris["target_names"],
rounded=True)
# Zeigen des Plots
plt.show()
if __name__ == '__main__':
data = DescisionTree()
......@@ -3,9 +3,9 @@ from sklearn import linear_model
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class StockMarket:
def __init__(self, file_path: str):
# Das Dictionary wurde in eine csv Datei umgewandelt
# Einlesen der CSV Datei, und Umwandeln der Year und Month Spalte zu einer time Spalte
self.df = pd.read_csv(file_path, parse_dates={'time': [1, 2]})
......@@ -17,13 +17,11 @@ class StockMarket:
self.stock_index_price = self.df['Stock_Index_Price']
def look_at_the_data(self):
# Schreiben des Datensatzes in die Konsole
# ProTip: .to_string() verhindert, dass der Datensatz gekürzt dargestellt wird
print(self.df.to_string())
def plot_initial(self):
# Erstellen einer neuen Figur und alle Plots darin zu Plotten
# ist nötig, da wir mehrere Plots in einem Fenster plotten wollen
fig, ax = plt.subplots(2, 1)
......@@ -59,7 +57,6 @@ class StockMarket:
print(f"Prediction: Interest_Rate=2.1, Unemployment_Rate=6.0, Stock_Index_Price = {reg.predict([[2.1, 6.0]])}")
def plot_3d(self):
# Erstellen einer 3d Achse
ax = plt.figure().gca(projection='3d')
......@@ -76,12 +73,10 @@ class StockMarket:
plt.show()
if __name__ == '__main__':
data = StockMarket('../../res/stock_market.csv')
# data.look_at_the_data()
# data.plot_initial()
# data.linear_regression()
data.plot_3d()
\ No newline at end of file
data.plot_3d()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment