diff --git a/tracking/management/commands/import_gpx.py b/tracking/management/commands/import_gpx.py index 9e161ed..bf91039 100644 --- a/tracking/management/commands/import_gpx.py +++ b/tracking/management/commands/import_gpx.py @@ -1,8 +1,7 @@ __author__ = 'tyrel' from django.core.management.base import BaseCommand, CommandError from django.contrib.auth.models import User -from tracking.models import Track, Segment, Point -from django.contrib.gis import geos +from tracking.models import Track import xmltodict import os @@ -31,36 +30,7 @@ class Command(BaseCommand): def import_file(self, filename): with open(filename) as fd: - obj = dict(xmltodict.parse(fd.read()))['gpx'] + gpx = dict(xmltodict.parse(fd.read()))['gpx'] user = self.user - create_points_segments_tracks(obj, user) - - -def create_points_segments_tracks(obj, user): - # Get some Meta data from the gpx file. - metadata = obj['metadata'] - start_time = metadata['time'] - # Get the the track and segments from the gpx file. - tracks = obj['trk'] - name = tracks['name'] - segments = tracks['trkseg'] - # Check unimplemented features - if isinstance(segments, list): - raise NotImplemented( - "Haven't run across one that has multiple segments. " - "Contribute a patch if you have.") - elif isinstance(segments, dict): - track = Track.objects.create(user=user, - start=start_time, - name=name) - points = segments['trkpt'] - segment = Segment(track=track, time=points[0]['time']) - segment.save() - for pt in points: - lat, lon = float(pt['@lat']), float(pt['@lon']) - poly_pt = geos.Point(lon, lat) - Point.objects.create(segment=segment, - time=pt['time'], - elevation=pt['ele'], - point=poly_pt) + Track.objects.create_from_gpx(gpx, user) diff --git a/tracking/models.py b/tracking/models.py index 2a6faeb..6427950 100644 --- a/tracking/models.py +++ b/tracking/models.py @@ -1,14 +1,46 @@ from django.contrib.gis.db import models +from django.contrib.gis import geos +from django.db import transaction # Create your models here. from django.contrib.auth.models import User +class TrackManager(models.Manager): + def create_from_gpx(cls, obj, user): + # Get some Meta data from the gpx file. + metadata = obj['metadata'] + start_time = metadata['time'] + # Get the the track and segments from the gpx file. + tracks = obj['trk'] + name = tracks['name'] + segments = tracks['trkseg'] + # Check unimplemented features + + with transaction.atomic(): + track = Track.objects.create(user=user, + start=start_time, + name=name) + points = segments['trkpt'] + segment = Segment(track=track, time=points[0]['time']) + segment.save() + for pt in points: + lat, lon = float(pt['@lat']), float(pt['@lon']) + poly_pt = geos.Point(lon, lat) + Point.objects.create(segment=segment, + time=pt['time'], + elevation=pt['ele'], + point=poly_pt) + return track + raise Exception("Track not created, rolling back DB") + + class Track(models.Model): user = models.ForeignKey(User, blank=False) start = models.DateTimeField() name = models.CharField(max_length=1024, blank=False) description = models.TextField(blank=True) + objects = TrackManager() @property def finish(self): diff --git a/tracking/tests.py b/tracking/tests.py index a49dafe..76bd7ef 100644 --- a/tracking/tests.py +++ b/tracking/tests.py @@ -1,5 +1,6 @@ from django.test import TestCase from django.core import management +from tracking.models import Track from django.core.management.base import CommandError # Create your tests here. @@ -19,6 +20,9 @@ class ImportGPXManagementTest(TestCase): management.call_command('import_gpx', 'tyrel', '/noop/nodir/') def test_import(self): + track_count = Track.objects.all().count() management.call_command('import_gpx', 'tyrel', 'testing/example-ride.gpx') + new_track_count = Track.objects.all().count() + self.assertNotEqual(track_count, new_track_count)